mirror of
https://github.com/roc-lang/roc.git
synced 2025-10-02 16:21:11 +00:00
Merge remote-tracking branch 'origin/tail-call-elimination' into gen-optional-field
This commit is contained in:
commit
3e12f1a309
12 changed files with 894 additions and 448 deletions
|
@ -627,16 +627,17 @@ impl<'a> Context<'a> {
|
|||
|
||||
let v_orig = v;
|
||||
|
||||
// NOTE deviation from lean, insert into local context
|
||||
let mut ctx = self.clone();
|
||||
ctx.local_context.join_points.insert(*j, (xs, v_orig));
|
||||
|
||||
let (v, v_live_vars) = {
|
||||
let ctx = self.update_var_info_with_params(xs);
|
||||
let ctx = ctx.update_var_info_with_params(xs);
|
||||
ctx.visit_stmt(v)
|
||||
};
|
||||
|
||||
let v = self.add_dec_for_dead_params(xs, v, &v_live_vars);
|
||||
let mut ctx = self.clone();
|
||||
|
||||
// NOTE deviation from lean, insert into local context
|
||||
ctx.local_context.join_points.insert(*j, (xs, v_orig));
|
||||
let v = ctx.add_dec_for_dead_params(xs, v, &v_live_vars);
|
||||
let mut ctx = ctx.clone();
|
||||
|
||||
update_jp_live_vars(*j, xs, v, &mut ctx.jp_live_vars);
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@ pub struct PartialProc<'a> {
|
|||
pub annotation: Variable,
|
||||
pub pattern_symbols: Vec<'a, Symbol>,
|
||||
pub body: roc_can::expr::Expr,
|
||||
pub is_tail_recursive: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
|
@ -39,6 +40,7 @@ pub struct Proc<'a> {
|
|||
pub body: Stmt<'a>,
|
||||
pub closes_over: Layout<'a>,
|
||||
pub ret_layout: Layout<'a>,
|
||||
pub is_tail_recursive: bool,
|
||||
}
|
||||
|
||||
impl<'a> Proc<'a> {
|
||||
|
@ -128,6 +130,7 @@ impl<'a> Procs<'a> {
|
|||
annotation: Variable,
|
||||
loc_args: std::vec::Vec<(Variable, Located<roc_can::pattern::Pattern>)>,
|
||||
loc_body: Located<roc_can::expr::Expr>,
|
||||
is_tail_recursive: bool,
|
||||
ret_var: Variable,
|
||||
) {
|
||||
match patterns_to_when(env, layout_cache, loc_args, ret_var, loc_body) {
|
||||
|
@ -142,6 +145,7 @@ impl<'a> Procs<'a> {
|
|||
annotation,
|
||||
pattern_symbols,
|
||||
body: body.value,
|
||||
is_tail_recursive,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
@ -174,6 +178,9 @@ impl<'a> Procs<'a> {
|
|||
ret_var: Variable,
|
||||
layout_cache: &mut LayoutCache<'a>,
|
||||
) -> Result<Layout<'a>, RuntimeError> {
|
||||
// anonymous functions cannot reference themselves, therefore cannot be tail-recursive
|
||||
let is_tail_recursive = false;
|
||||
|
||||
match patterns_to_when(env, layout_cache, loc_args, ret_var, loc_body) {
|
||||
Ok((pattern_vars, pattern_symbols, body)) => {
|
||||
// an anonymous closure. These will always be specialized already
|
||||
|
@ -212,6 +219,7 @@ impl<'a> Procs<'a> {
|
|||
annotation,
|
||||
pattern_symbols,
|
||||
body: body.value,
|
||||
is_tail_recursive,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
@ -221,6 +229,7 @@ impl<'a> Procs<'a> {
|
|||
annotation,
|
||||
pattern_symbols,
|
||||
body: body.value,
|
||||
is_tail_recursive,
|
||||
};
|
||||
|
||||
// Mark this proc as in-progress, so if we're dealing with
|
||||
|
@ -370,7 +379,7 @@ impl<'a, 'i> Env<'a, 'i> {
|
|||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Copy, Eq, Hash)]
|
||||
pub struct JoinPointId(Symbol);
|
||||
pub struct JoinPointId(pub Symbol);
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub struct Param<'a> {
|
||||
|
@ -991,6 +1000,7 @@ fn specialize<'a>(
|
|||
annotation,
|
||||
pattern_symbols,
|
||||
body,
|
||||
is_tail_recursive,
|
||||
} = partial_proc;
|
||||
|
||||
// unify the called function with the specialized signature, then specialize the function body
|
||||
|
@ -999,9 +1009,7 @@ fn specialize<'a>(
|
|||
|
||||
debug_assert!(matches!(unified, roc_unify::unify::Unified::Success(_)));
|
||||
|
||||
let ret_symbol = env.unique_symbol();
|
||||
let hole = env.arena.alloc(Stmt::Ret(ret_symbol));
|
||||
let specialized_body = with_hole(env, body, procs, layout_cache, ret_symbol, hole);
|
||||
let specialized_body = from_can(env, body, procs, layout_cache);
|
||||
|
||||
// reset subs, so we don't get type errors when specializing for a different signature
|
||||
env.subs.rollback_to(snapshot);
|
||||
|
@ -1020,6 +1028,11 @@ fn specialize<'a>(
|
|||
proc_args.push((layout, *arg_name));
|
||||
}
|
||||
|
||||
let proc_args = proc_args.into_bump_slice();
|
||||
|
||||
let specialized_body =
|
||||
crate::tail_recursion::make_tail_recursive(env, proc_name, specialized_body, proc_args);
|
||||
|
||||
let ret_layout = layout_cache
|
||||
.from_var(&env.arena, ret_var, env.subs)
|
||||
.unwrap_or_else(|err| panic!("TODO handle invalid function {:?}", err));
|
||||
|
@ -1029,10 +1042,11 @@ fn specialize<'a>(
|
|||
|
||||
let proc = Proc {
|
||||
name: proc_name,
|
||||
args: proc_args.into_bump_slice(),
|
||||
args: proc_args,
|
||||
body: specialized_body,
|
||||
closes_over: closes_over_layout,
|
||||
ret_layout,
|
||||
is_tail_recursive,
|
||||
};
|
||||
|
||||
Ok(proc)
|
||||
|
@ -1088,30 +1102,27 @@ pub fn with_hole<'a>(
|
|||
},
|
||||
LetNonRec(def, cont, _, _) => {
|
||||
if let roc_can::pattern::Pattern::Identifier(symbol) = &def.loc_pattern.value {
|
||||
if let Closure(_, _, _, _, _) = &def.loc_expr.value {
|
||||
// Now that we know for sure it's a closure, get an owned
|
||||
// version of these variant args so we can use them properly.
|
||||
match def.loc_expr.value {
|
||||
Closure(ann, _, _, loc_args, boxed_body) => {
|
||||
// Extract Procs, but discard the resulting Expr::Load.
|
||||
// That Load looks up the pointer, which we won't use here!
|
||||
if let Closure(ann, _, recursivity, loc_args, boxed_body) = def.loc_expr.value {
|
||||
// Extract Procs, but discard the resulting Expr::Load.
|
||||
// That Load looks up the pointer, which we won't use here!
|
||||
|
||||
let (loc_body, ret_var) = *boxed_body;
|
||||
let (loc_body, ret_var) = *boxed_body;
|
||||
|
||||
procs.insert_named(
|
||||
env,
|
||||
layout_cache,
|
||||
*symbol,
|
||||
ann,
|
||||
loc_args,
|
||||
loc_body,
|
||||
ret_var,
|
||||
);
|
||||
let is_tail_recursive =
|
||||
matches!(recursivity, roc_can::expr::Recursive::TailRecursive);
|
||||
|
||||
return with_hole(env, cont.value, procs, layout_cache, assigned, hole);
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
procs.insert_named(
|
||||
env,
|
||||
layout_cache,
|
||||
*symbol,
|
||||
ann,
|
||||
loc_args,
|
||||
loc_body,
|
||||
is_tail_recursive,
|
||||
ret_var,
|
||||
);
|
||||
|
||||
return with_hole(env, cont.value, procs, layout_cache, assigned, hole);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1175,28 +1186,27 @@ pub fn with_hole<'a>(
|
|||
// because Roc is strict, only functions can be recursive!
|
||||
for def in defs.into_iter() {
|
||||
if let roc_can::pattern::Pattern::Identifier(symbol) = &def.loc_pattern.value {
|
||||
// Now that we know for sure it's a closure, get an owned
|
||||
// version of these variant args so we can use them properly.
|
||||
match def.loc_expr.value {
|
||||
Closure(ann, _, _, loc_args, boxed_body) => {
|
||||
// Extract Procs, but discard the resulting Expr::Load.
|
||||
// That Load looks up the pointer, which we won't use here!
|
||||
if let Closure(ann, _, recursivity, loc_args, boxed_body) = def.loc_expr.value {
|
||||
// Extract Procs, but discard the resulting Expr::Load.
|
||||
// That Load looks up the pointer, which we won't use here!
|
||||
|
||||
let (loc_body, ret_var) = *boxed_body;
|
||||
let (loc_body, ret_var) = *boxed_body;
|
||||
|
||||
procs.insert_named(
|
||||
env,
|
||||
layout_cache,
|
||||
*symbol,
|
||||
ann,
|
||||
loc_args,
|
||||
loc_body,
|
||||
ret_var,
|
||||
);
|
||||
let is_tail_recursive =
|
||||
matches!(recursivity, roc_can::expr::Recursive::TailRecursive);
|
||||
|
||||
continue;
|
||||
}
|
||||
_ => unreachable!("recursive value is not a function"),
|
||||
procs.insert_named(
|
||||
env,
|
||||
layout_cache,
|
||||
*symbol,
|
||||
ann,
|
||||
loc_args,
|
||||
loc_body,
|
||||
is_tail_recursive,
|
||||
ret_var,
|
||||
);
|
||||
|
||||
continue;
|
||||
}
|
||||
}
|
||||
unreachable!("recursive value does not have Identifier pattern")
|
||||
|
@ -1456,68 +1466,119 @@ pub fn with_hole<'a>(
|
|||
.from_var(env.arena, cond_var, env.subs)
|
||||
.expect("invalid cond_layout");
|
||||
|
||||
let assigned_in_jump = env.unique_symbol();
|
||||
let id = JoinPointId(env.unique_symbol());
|
||||
let jump = env
|
||||
.arena
|
||||
.alloc(Stmt::Jump(id, env.arena.alloc([assigned_in_jump])));
|
||||
// if the hole is a return, then we don't need to merge the two
|
||||
// branches together again, we can just immediately return
|
||||
let is_terminated = matches!(hole, Stmt::Ret(_));
|
||||
|
||||
let mut stmt = with_hole(
|
||||
env,
|
||||
final_else.value,
|
||||
procs,
|
||||
layout_cache,
|
||||
assigned_in_jump,
|
||||
jump,
|
||||
);
|
||||
if is_terminated {
|
||||
let terminator = hole;
|
||||
|
||||
for (loc_cond, loc_then) in branches.into_iter().rev() {
|
||||
let branching_symbol = env.unique_symbol();
|
||||
let then = with_hole(
|
||||
let mut stmt = with_hole(
|
||||
env,
|
||||
loc_then.value,
|
||||
final_else.value,
|
||||
procs,
|
||||
layout_cache,
|
||||
assigned,
|
||||
terminator,
|
||||
);
|
||||
|
||||
for (loc_cond, loc_then) in branches.into_iter().rev() {
|
||||
let branching_symbol = env.unique_symbol();
|
||||
let then = with_hole(
|
||||
env,
|
||||
loc_then.value,
|
||||
procs,
|
||||
layout_cache,
|
||||
assigned,
|
||||
terminator,
|
||||
);
|
||||
|
||||
stmt = Stmt::Cond {
|
||||
cond_symbol: branching_symbol,
|
||||
branching_symbol,
|
||||
cond_layout: cond_layout.clone(),
|
||||
branching_layout: cond_layout.clone(),
|
||||
pass: env.arena.alloc(then),
|
||||
fail: env.arena.alloc(stmt),
|
||||
ret_layout: ret_layout.clone(),
|
||||
};
|
||||
|
||||
// add condition
|
||||
stmt = with_hole(
|
||||
env,
|
||||
loc_cond.value,
|
||||
procs,
|
||||
layout_cache,
|
||||
branching_symbol,
|
||||
env.arena.alloc(stmt),
|
||||
);
|
||||
}
|
||||
stmt
|
||||
} else {
|
||||
let assigned_in_jump = env.unique_symbol();
|
||||
let id = JoinPointId(env.unique_symbol());
|
||||
|
||||
let terminator = env
|
||||
.arena
|
||||
.alloc(Stmt::Jump(id, env.arena.alloc([assigned_in_jump])));
|
||||
|
||||
let mut stmt = with_hole(
|
||||
env,
|
||||
final_else.value,
|
||||
procs,
|
||||
layout_cache,
|
||||
assigned_in_jump,
|
||||
jump,
|
||||
terminator,
|
||||
);
|
||||
|
||||
stmt = Stmt::Cond {
|
||||
cond_symbol: branching_symbol,
|
||||
branching_symbol,
|
||||
cond_layout: cond_layout.clone(),
|
||||
branching_layout: cond_layout.clone(),
|
||||
pass: env.arena.alloc(then),
|
||||
fail: env.arena.alloc(stmt),
|
||||
ret_layout: ret_layout.clone(),
|
||||
for (loc_cond, loc_then) in branches.into_iter().rev() {
|
||||
let branching_symbol = env.unique_symbol();
|
||||
let then = with_hole(
|
||||
env,
|
||||
loc_then.value,
|
||||
procs,
|
||||
layout_cache,
|
||||
assigned_in_jump,
|
||||
terminator,
|
||||
);
|
||||
|
||||
stmt = Stmt::Cond {
|
||||
cond_symbol: branching_symbol,
|
||||
branching_symbol,
|
||||
cond_layout: cond_layout.clone(),
|
||||
branching_layout: cond_layout.clone(),
|
||||
pass: env.arena.alloc(then),
|
||||
fail: env.arena.alloc(stmt),
|
||||
ret_layout: ret_layout.clone(),
|
||||
};
|
||||
|
||||
// add condition
|
||||
stmt = with_hole(
|
||||
env,
|
||||
loc_cond.value,
|
||||
procs,
|
||||
layout_cache,
|
||||
branching_symbol,
|
||||
env.arena.alloc(stmt),
|
||||
);
|
||||
}
|
||||
|
||||
let layout = layout_cache
|
||||
.from_var(env.arena, branch_var, env.subs)
|
||||
.unwrap_or_else(|err| panic!("TODO turn fn_var into a RuntimeError {:?}", err));
|
||||
|
||||
let param = Param {
|
||||
symbol: assigned,
|
||||
layout,
|
||||
borrow: false,
|
||||
};
|
||||
|
||||
// add condition
|
||||
stmt = with_hole(
|
||||
env,
|
||||
loc_cond.value,
|
||||
procs,
|
||||
layout_cache,
|
||||
branching_symbol,
|
||||
env.arena.alloc(stmt),
|
||||
);
|
||||
}
|
||||
|
||||
let layout = layout_cache
|
||||
.from_var(env.arena, branch_var, env.subs)
|
||||
.unwrap_or_else(|err| panic!("TODO turn fn_var into a RuntimeError {:?}", err));
|
||||
|
||||
let param = Param {
|
||||
symbol: assigned,
|
||||
layout,
|
||||
borrow: false,
|
||||
};
|
||||
|
||||
Stmt::Join {
|
||||
id,
|
||||
parameters: env.arena.alloc([param]),
|
||||
remainder: env.arena.alloc(stmt),
|
||||
continuation: hole,
|
||||
Stmt::Join {
|
||||
id,
|
||||
parameters: env.arena.alloc([param]),
|
||||
remainder: env.arena.alloc(stmt),
|
||||
continuation: hole,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1901,6 +1962,89 @@ pub fn from_can<'a>(
|
|||
use roc_can::expr::Expr::*;
|
||||
|
||||
match can_expr {
|
||||
When {
|
||||
cond_var,
|
||||
expr_var,
|
||||
region,
|
||||
loc_cond,
|
||||
branches,
|
||||
} => {
|
||||
let cond_symbol = if let roc_can::expr::Expr::Var(symbol) = loc_cond.value {
|
||||
symbol
|
||||
} else {
|
||||
env.unique_symbol()
|
||||
};
|
||||
|
||||
let mut stmt = from_can_when(
|
||||
env,
|
||||
cond_var,
|
||||
expr_var,
|
||||
region,
|
||||
cond_symbol,
|
||||
branches,
|
||||
layout_cache,
|
||||
procs,
|
||||
None,
|
||||
);
|
||||
|
||||
// define the `when` condition
|
||||
if let roc_can::expr::Expr::Var(_) = loc_cond.value {
|
||||
// do nothing
|
||||
} else {
|
||||
stmt = with_hole(
|
||||
env,
|
||||
loc_cond.value,
|
||||
procs,
|
||||
layout_cache,
|
||||
cond_symbol,
|
||||
env.arena.alloc(stmt),
|
||||
);
|
||||
};
|
||||
|
||||
stmt
|
||||
}
|
||||
If {
|
||||
cond_var,
|
||||
branch_var,
|
||||
branches,
|
||||
final_else,
|
||||
} => {
|
||||
let ret_layout = layout_cache
|
||||
.from_var(env.arena, branch_var, env.subs)
|
||||
.expect("invalid ret_layout");
|
||||
let cond_layout = layout_cache
|
||||
.from_var(env.arena, cond_var, env.subs)
|
||||
.expect("invalid cond_layout");
|
||||
|
||||
let mut stmt = from_can(env, final_else.value, procs, layout_cache);
|
||||
|
||||
for (loc_cond, loc_then) in branches.into_iter().rev() {
|
||||
let branching_symbol = env.unique_symbol();
|
||||
let then = from_can(env, loc_then.value, procs, layout_cache);
|
||||
|
||||
stmt = Stmt::Cond {
|
||||
cond_symbol: branching_symbol,
|
||||
branching_symbol,
|
||||
cond_layout: cond_layout.clone(),
|
||||
branching_layout: cond_layout.clone(),
|
||||
pass: env.arena.alloc(then),
|
||||
fail: env.arena.alloc(stmt),
|
||||
ret_layout: ret_layout.clone(),
|
||||
};
|
||||
|
||||
// add condition
|
||||
stmt = with_hole(
|
||||
env,
|
||||
loc_cond.value,
|
||||
procs,
|
||||
layout_cache,
|
||||
branching_symbol,
|
||||
env.arena.alloc(stmt),
|
||||
);
|
||||
}
|
||||
|
||||
stmt
|
||||
}
|
||||
LetRec(defs, cont, _, _) => {
|
||||
// because Roc is strict, only functions can be recursive!
|
||||
for def in defs.into_iter() {
|
||||
|
@ -1908,12 +2052,15 @@ pub fn from_can<'a>(
|
|||
// Now that we know for sure it's a closure, get an owned
|
||||
// version of these variant args so we can use them properly.
|
||||
match def.loc_expr.value {
|
||||
Closure(ann, _, _, loc_args, boxed_body) => {
|
||||
Closure(ann, _, recursivity, loc_args, boxed_body) => {
|
||||
// Extract Procs, but discard the resulting Expr::Load.
|
||||
// That Load looks up the pointer, which we won't use here!
|
||||
|
||||
let (loc_body, ret_var) = *boxed_body;
|
||||
|
||||
let is_tail_recursive =
|
||||
matches!(recursivity, roc_can::expr::Recursive::TailRecursive);
|
||||
|
||||
procs.insert_named(
|
||||
env,
|
||||
layout_cache,
|
||||
|
@ -1921,6 +2068,7 @@ pub fn from_can<'a>(
|
|||
ann,
|
||||
loc_args,
|
||||
loc_body,
|
||||
is_tail_recursive,
|
||||
ret_var,
|
||||
);
|
||||
|
||||
|
@ -1940,12 +2088,15 @@ pub fn from_can<'a>(
|
|||
// Now that we know for sure it's a closure, get an owned
|
||||
// version of these variant args so we can use them properly.
|
||||
match def.loc_expr.value {
|
||||
Closure(ann, _, _, loc_args, boxed_body) => {
|
||||
Closure(ann, _, recursivity, loc_args, boxed_body) => {
|
||||
// Extract Procs, but discard the resulting Expr::Load.
|
||||
// That Load looks up the pointer, which we won't use here!
|
||||
|
||||
let (loc_body, ret_var) = *boxed_body;
|
||||
|
||||
let is_tail_recursive =
|
||||
matches!(recursivity, roc_can::expr::Recursive::TailRecursive);
|
||||
|
||||
procs.insert_named(
|
||||
env,
|
||||
layout_cache,
|
||||
|
@ -1953,6 +2104,7 @@ pub fn from_can<'a>(
|
|||
ann,
|
||||
loc_args,
|
||||
loc_body,
|
||||
is_tail_recursive,
|
||||
ret_var,
|
||||
);
|
||||
|
||||
|
@ -2648,11 +2800,11 @@ fn store_pattern<'a>(
|
|||
}
|
||||
|
||||
Shadowed(_region, _ident) => {
|
||||
return Err(&"TODO");
|
||||
return Err(&"shadowed");
|
||||
}
|
||||
|
||||
UnsupportedPattern(_region) => {
|
||||
return Err(&"TODO");
|
||||
return Err(&"unsupported pattern");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
pub mod inc_dec;
|
||||
pub mod ir;
|
||||
pub mod layout;
|
||||
pub mod tail_recursion;
|
||||
|
||||
// Temporary, while we can build up test cases and optimize the exhaustiveness checking.
|
||||
// For now, following this warning's advice will lead to nasty type inference errors.
|
||||
|
|
201
compiler/mono/src/tail_recursion.rs
Normal file
201
compiler/mono/src/tail_recursion.rs
Normal file
|
@ -0,0 +1,201 @@
|
|||
use crate::ir::{CallType, Env, Expr, JoinPointId, Param, Stmt};
|
||||
use crate::layout::Layout;
|
||||
use bumpalo::collections::Vec;
|
||||
use bumpalo::Bump;
|
||||
use roc_module::symbol::Symbol;
|
||||
|
||||
pub fn make_tail_recursive<'a>(
|
||||
env: &mut Env<'a, '_>,
|
||||
needle: Symbol,
|
||||
stmt: Stmt<'a>,
|
||||
args: &'a [(Layout<'a>, Symbol)],
|
||||
) -> Stmt<'a> {
|
||||
let id = JoinPointId(env.unique_symbol());
|
||||
|
||||
let alloced = env.arena.alloc(stmt);
|
||||
match insert_jumps(env.arena, alloced, id, needle) {
|
||||
None => alloced.clone(),
|
||||
Some(new) => {
|
||||
// jumps were inserted, we must now add a join point
|
||||
|
||||
let params = Vec::from_iter_in(
|
||||
args.iter().map(|(layout, symbol)| Param {
|
||||
symbol: *symbol,
|
||||
layout: layout.clone(),
|
||||
borrow: true,
|
||||
}),
|
||||
env.arena,
|
||||
)
|
||||
.into_bump_slice();
|
||||
|
||||
let args = Vec::from_iter_in(args.iter().map(|t| t.1), env.arena).into_bump_slice();
|
||||
|
||||
let jump = env.arena.alloc(Stmt::Jump(id, args));
|
||||
|
||||
Stmt::Join {
|
||||
id,
|
||||
remainder: jump,
|
||||
parameters: params,
|
||||
continuation: new,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn insert_jumps<'a>(
|
||||
arena: &'a Bump,
|
||||
stmt: &'a Stmt<'a>,
|
||||
goal_id: JoinPointId,
|
||||
needle: Symbol,
|
||||
) -> Option<&'a Stmt<'a>> {
|
||||
use Stmt::*;
|
||||
|
||||
match stmt {
|
||||
Let(
|
||||
symbol,
|
||||
Expr::FunctionCall {
|
||||
call_type: CallType::ByName(fsym),
|
||||
args,
|
||||
..
|
||||
},
|
||||
_,
|
||||
Stmt::Ret(rsym),
|
||||
) if needle == *fsym && symbol == rsym => {
|
||||
// replace the call and return with a jump
|
||||
|
||||
let jump = Stmt::Jump(goal_id, args);
|
||||
|
||||
Some(arena.alloc(jump))
|
||||
}
|
||||
|
||||
Let(symbol, expr, layout, cont) => {
|
||||
let opt_cont = insert_jumps(arena, cont, goal_id, needle);
|
||||
|
||||
if opt_cont.is_some() {
|
||||
let cont = opt_cont.unwrap_or(cont);
|
||||
|
||||
Some(arena.alloc(Let(*symbol, expr.clone(), layout.clone(), cont)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Join {
|
||||
id,
|
||||
parameters,
|
||||
remainder,
|
||||
continuation,
|
||||
} => {
|
||||
let opt_remainder = insert_jumps(arena, remainder, goal_id, needle);
|
||||
let opt_continuation = insert_jumps(arena, continuation, goal_id, needle);
|
||||
|
||||
if opt_remainder.is_some() || opt_continuation.is_some() {
|
||||
let remainder = opt_remainder.unwrap_or(remainder);
|
||||
let continuation = opt_continuation.unwrap_or_else(|| *continuation);
|
||||
|
||||
Some(arena.alloc(Join {
|
||||
id: *id,
|
||||
parameters,
|
||||
remainder,
|
||||
continuation,
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Cond {
|
||||
cond_symbol,
|
||||
cond_layout,
|
||||
branching_symbol,
|
||||
branching_layout,
|
||||
pass,
|
||||
fail,
|
||||
ret_layout,
|
||||
} => {
|
||||
let opt_pass = insert_jumps(arena, pass, goal_id, needle);
|
||||
let opt_fail = insert_jumps(arena, fail, goal_id, needle);
|
||||
|
||||
if opt_pass.is_some() || opt_fail.is_some() {
|
||||
let pass = opt_pass.unwrap_or(pass);
|
||||
let fail = opt_fail.unwrap_or_else(|| *fail);
|
||||
|
||||
Some(arena.alloc(Cond {
|
||||
cond_symbol: *cond_symbol,
|
||||
cond_layout: cond_layout.clone(),
|
||||
branching_symbol: *branching_symbol,
|
||||
branching_layout: branching_layout.clone(),
|
||||
pass,
|
||||
fail,
|
||||
ret_layout: ret_layout.clone(),
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Switch {
|
||||
cond_symbol,
|
||||
cond_layout,
|
||||
branches,
|
||||
default_branch,
|
||||
ret_layout,
|
||||
} => {
|
||||
let opt_default = insert_jumps(arena, default_branch, goal_id, needle);
|
||||
|
||||
let mut did_change = false;
|
||||
|
||||
let opt_branches = Vec::from_iter_in(
|
||||
branches.iter().map(|(label, branch)| {
|
||||
match insert_jumps(arena, branch, goal_id, needle) {
|
||||
None => None,
|
||||
Some(branch) => {
|
||||
did_change = true;
|
||||
Some((*label, branch.clone()))
|
||||
}
|
||||
}
|
||||
}),
|
||||
arena,
|
||||
);
|
||||
|
||||
if opt_default.is_some() || did_change {
|
||||
let default_branch = opt_default.unwrap_or(default_branch);
|
||||
|
||||
let branches = if did_change {
|
||||
let new = Vec::from_iter_in(
|
||||
opt_branches.into_iter().zip(branches.iter()).map(
|
||||
|(opt_branch, branch)| match opt_branch {
|
||||
None => branch.clone(),
|
||||
Some(new_branch) => new_branch,
|
||||
},
|
||||
),
|
||||
arena,
|
||||
);
|
||||
|
||||
new.into_bump_slice()
|
||||
} else {
|
||||
branches
|
||||
};
|
||||
|
||||
Some(arena.alloc(Switch {
|
||||
cond_symbol: *cond_symbol,
|
||||
cond_layout: cond_layout.clone(),
|
||||
default_branch,
|
||||
branches,
|
||||
ret_layout: ret_layout.clone(),
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Ret(_) => None,
|
||||
Inc(symbol, cont) => match insert_jumps(arena, cont, goal_id, needle) {
|
||||
Some(cont) => Some(arena.alloc(Inc(*symbol, cont))),
|
||||
None => None,
|
||||
},
|
||||
Dec(symbol, cont) => match insert_jumps(arena, cont, goal_id, needle) {
|
||||
Some(cont) => Some(arena.alloc(Dec(*symbol, cont))),
|
||||
None => None,
|
||||
},
|
||||
|
||||
Jump(_, _) => None,
|
||||
RuntimeError(_) => None,
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue