Merge remote-tracking branch 'origin/tail-call-elimination' into gen-optional-field

This commit is contained in:
Folkert 2020-08-13 00:21:21 +02:00
commit 3e12f1a309
12 changed files with 894 additions and 448 deletions

View file

@ -113,7 +113,7 @@ mod cli_run {
assert_eq!(&out.stderr, ""); assert_eq!(&out.stderr, "");
assert!(&out assert!(&out
.stdout .stdout
.ends_with("[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2]\n")); .ends_with("[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n"));
assert!(out.status.success()); assert!(out.status.success());
} }
} }

View file

@ -105,7 +105,10 @@ pub fn gen(
Declare(def) | Builtin(def) => match def.loc_pattern.value { Declare(def) | Builtin(def) => match def.loc_pattern.value {
Identifier(symbol) => { Identifier(symbol) => {
match def.loc_expr.value { match def.loc_expr.value {
Closure(annotation, _, _, loc_args, boxed_body) => { Closure(annotation, _, recursivity, loc_args, boxed_body) => {
let is_tail_recursive =
matches!(recursivity, roc_can::expr::Recursive::TailRecursive);
let (loc_body, ret_var) = *boxed_body; let (loc_body, ret_var) = *boxed_body;
// If this is an exposed symbol, we need to // If this is an exposed symbol, we need to
@ -143,6 +146,7 @@ pub fn gen(
annotation, annotation,
loc_args, loc_args,
loc_body, loc_body,
is_tail_recursive,
ret_var, ret_var,
); );
} }
@ -154,6 +158,7 @@ pub fn gen(
pattern_symbols: bumpalo::collections::Vec::new_in( pattern_symbols: bumpalo::collections::Vec::new_in(
mono_env.arena, mono_env.arena,
), ),
is_tail_recursive: false,
body, body,
}; };

View file

@ -184,20 +184,15 @@ pub fn construct_optimization_passes<'a>(
// function passes // function passes
fpm.add_cfg_simplification_pass();
mpm.add_cfg_simplification_pass();
fpm.add_jump_threading_pass();
mpm.add_jump_threading_pass();
fpm.add_memcpy_optimize_pass(); // this one is very important fpm.add_memcpy_optimize_pass(); // this one is very important
// In my testing, these don't do much for quicksort fpm.add_licm_pass();
// fpm.add_basic_alias_analysis_pass();
// fpm.add_jump_threading_pass();
// fpm.add_instruction_combining_pass();
// fpm.add_licm_pass();
// fpm.add_loop_unroll_pass();
// fpm.add_scalar_repl_aggregates_pass_ssa();
// fpm.add_cfg_simplification_pass();
// fpm.add_jump_threading_pass();
// module passes
// fpm.add_promote_memory_to_register_pass();
} }
} }
@ -699,7 +694,17 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>(
result result
} }
Ret(symbol) => load_symbol(env, scope, symbol), Ret(symbol) => {
let value = load_symbol(env, scope, symbol);
if let Some(block) = env.builder.get_insert_block() {
if block.get_terminator().is_none() {
env.builder.build_return(Some(&value));
}
}
value
}
Cond { Cond {
branching_symbol, branching_symbol,
@ -727,7 +732,7 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>(
&dyn inkwell::values::BasicValue<'_>, &dyn inkwell::values::BasicValue<'_>,
inkwell::basic_block::BasicBlock<'_>, inkwell::basic_block::BasicBlock<'_>,
)> = std::vec::Vec::with_capacity(2); )> = std::vec::Vec::with_capacity(2);
let cont_block = context.append_basic_block(parent, "branchcont"); let cont_block = context.append_basic_block(parent, "condbranchcont");
builder.build_conditional_branch(value, then_block, else_block); builder.build_conditional_branch(value, then_block, else_block);
@ -828,9 +833,6 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>(
// construct the blocks that may jump to this join point // construct the blocks that may jump to this join point
build_exp_stmt(env, layout_ids, scope, parent, remainder); build_exp_stmt(env, layout_ids, scope, parent, remainder);
// remove this join point again
scope.join_points.remove(&id);
for (ptr, param) in joinpoint_args.iter().zip(parameters.iter()) { for (ptr, param) in joinpoint_args.iter().zip(parameters.iter()) {
scope.insert(param.symbol, (param.layout.clone(), *ptr)); scope.insert(param.symbol, (param.layout.clone(), *ptr));
} }
@ -843,6 +845,9 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>(
// put the continuation in // put the continuation in
let result = build_exp_stmt(env, layout_ids, scope, parent, continuation); let result = build_exp_stmt(env, layout_ids, scope, parent, continuation);
// remove this join point again
scope.join_points.remove(&id);
cont_block.move_after(phi_block).unwrap(); cont_block.move_after(phi_block).unwrap();
result result
@ -1093,7 +1098,7 @@ fn decrement_refcount_list<'a, 'ctx, 'env>(
// build blocks // build blocks
let then_block = ctx.append_basic_block(parent, "then"); let then_block = ctx.append_basic_block(parent, "then");
let else_block = ctx.append_basic_block(parent, "else"); let else_block = ctx.append_basic_block(parent, "else");
let cont_block = ctx.append_basic_block(parent, "branchcont"); let cont_block = ctx.append_basic_block(parent, "dec_ref_branchcont");
builder.build_conditional_branch(comparison, then_block, else_block); builder.build_conditional_branch(comparison, then_block, else_block);
@ -1440,8 +1445,13 @@ pub fn build_proc<'a, 'ctx, 'env>(
let body = build_exp_stmt(env, layout_ids, &mut scope, fn_val, &proc.body); let body = build_exp_stmt(env, layout_ids, &mut scope, fn_val, &proc.body);
// only add a return if codegen did not already add one
if let Some(block) = builder.get_insert_block() {
if block.get_terminator().is_none() {
builder.build_return(Some(&body)); builder.build_return(Some(&body));
} }
}
}
pub fn verify_fn(fn_val: FunctionValue<'_>) { pub fn verify_fn(fn_val: FunctionValue<'_>) {
if !fn_val.verify(PRINT_FN_VERIFICATION_OUTPUT) { if !fn_val.verify(PRINT_FN_VERIFICATION_OUTPUT) {

View file

@ -265,9 +265,7 @@ mod gen_list {
assert_evals_to!( assert_evals_to!(
&format!("List.concat {} {}", slice_str1, slice_str2), &format!("List.concat {} {}", slice_str1, slice_str2),
expected_slice, expected_slice,
&'static [i64], &'static [i64]
|x| x,
true
); );
} }
@ -816,9 +814,7 @@ mod gen_list {
"# "#
), ),
&[4, 7, 19, 21], &[4, 7, 19, 21],
&'static [i64], &'static [i64]
|x| x,
true
); );
}) })
} }
@ -892,8 +888,6 @@ mod gen_list {
// ), // ),
// &[19, 7, 4, 21], // &[19, 7, 4, 21],
// &'static [i64], // &'static [i64],
// |x| x,
// true
// ); // );
// }) // })
// } // }
@ -967,8 +961,6 @@ mod gen_list {
// ), // ),
// 4, // 4,
// i64, // i64,
// |x| x,
// false
// ); // );
// }) // })
// } // }
@ -1020,9 +1012,7 @@ mod gen_list {
"# "#
), ),
&[1, 2, 3], &[1, 2, 3],
&'static [i64], &'static [i64]
|x| x,
true
); );
} }
@ -1041,9 +1031,7 @@ mod gen_list {
"# "#
), ),
&[0, 2, 3], &[0, 2, 3],
&'static [i64], &'static [i64]
|x| x,
true
); );
} }
@ -1060,9 +1048,7 @@ mod gen_list {
"# "#
), ),
&[1, 2, 3], &[1, 2, 3],
&'static [i64], &'static [i64]
|x| x,
true
); );
} }
} }

View file

@ -449,4 +449,25 @@ mod gen_primitives {
i64 i64
); );
} }
#[test]
fn factorial() {
assert_evals_to!(
indoc!(
r#"
factorial = \n, accum ->
when n is
0 ->
accum
_ ->
factorial (n - 1) (n * accum)
factorial 10 1
"#
),
3628800,
i64
);
}
} }

View file

@ -171,7 +171,8 @@ pub fn helper_without_uniqueness<'a>(
builder.position_at_end(basic_block); builder.position_at_end(basic_block);
let ret = roc_gen::llvm::build::build_exp_stmt( // builds the function body (return statement included)
roc_gen::llvm::build::build_exp_stmt(
&env, &env,
&mut layout_ids, &mut layout_ids,
&mut Scope::default(), &mut Scope::default(),
@ -179,8 +180,6 @@ pub fn helper_without_uniqueness<'a>(
&main_body, &main_body,
); );
builder.build_return(Some(&ret));
// Uncomment this to see the module's un-optimized LLVM instruction output: // Uncomment this to see the module's un-optimized LLVM instruction output:
// env.module.print_to_stderr(); // env.module.print_to_stderr();
@ -362,7 +361,8 @@ pub fn helper_with_uniqueness<'a>(
builder.position_at_end(basic_block); builder.position_at_end(basic_block);
let ret = roc_gen::llvm::build::build_exp_stmt( // builds the function body (return statement included)
roc_gen::llvm::build::build_exp_stmt(
&env, &env,
&mut layout_ids, &mut layout_ids,
&mut Scope::default(), &mut Scope::default(),
@ -370,8 +370,6 @@ pub fn helper_with_uniqueness<'a>(
&main_body, &main_body,
); );
builder.build_return(Some(&ret));
// you're in the version with uniqueness! // you're in the version with uniqueness!
// Uncomment this to see the module's un-optimized LLVM instruction output: // Uncomment this to see the module's un-optimized LLVM instruction output:

View file

@ -627,16 +627,17 @@ impl<'a> Context<'a> {
let v_orig = v; let v_orig = v;
// NOTE deviation from lean, insert into local context
let mut ctx = self.clone();
ctx.local_context.join_points.insert(*j, (xs, v_orig));
let (v, v_live_vars) = { let (v, v_live_vars) = {
let ctx = self.update_var_info_with_params(xs); let ctx = ctx.update_var_info_with_params(xs);
ctx.visit_stmt(v) ctx.visit_stmt(v)
}; };
let v = self.add_dec_for_dead_params(xs, v, &v_live_vars); let v = ctx.add_dec_for_dead_params(xs, v, &v_live_vars);
let mut ctx = self.clone(); let mut ctx = ctx.clone();
// NOTE deviation from lean, insert into local context
ctx.local_context.join_points.insert(*j, (xs, v_orig));
update_jp_live_vars(*j, xs, v, &mut ctx.jp_live_vars); update_jp_live_vars(*j, xs, v, &mut ctx.jp_live_vars);

View file

@ -23,6 +23,7 @@ pub struct PartialProc<'a> {
pub annotation: Variable, pub annotation: Variable,
pub pattern_symbols: Vec<'a, Symbol>, pub pattern_symbols: Vec<'a, Symbol>,
pub body: roc_can::expr::Expr, pub body: roc_can::expr::Expr,
pub is_tail_recursive: bool,
} }
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
@ -39,6 +40,7 @@ pub struct Proc<'a> {
pub body: Stmt<'a>, pub body: Stmt<'a>,
pub closes_over: Layout<'a>, pub closes_over: Layout<'a>,
pub ret_layout: Layout<'a>, pub ret_layout: Layout<'a>,
pub is_tail_recursive: bool,
} }
impl<'a> Proc<'a> { impl<'a> Proc<'a> {
@ -128,6 +130,7 @@ impl<'a> Procs<'a> {
annotation: Variable, annotation: Variable,
loc_args: std::vec::Vec<(Variable, Located<roc_can::pattern::Pattern>)>, loc_args: std::vec::Vec<(Variable, Located<roc_can::pattern::Pattern>)>,
loc_body: Located<roc_can::expr::Expr>, loc_body: Located<roc_can::expr::Expr>,
is_tail_recursive: bool,
ret_var: Variable, ret_var: Variable,
) { ) {
match patterns_to_when(env, layout_cache, loc_args, ret_var, loc_body) { match patterns_to_when(env, layout_cache, loc_args, ret_var, loc_body) {
@ -142,6 +145,7 @@ impl<'a> Procs<'a> {
annotation, annotation,
pattern_symbols, pattern_symbols,
body: body.value, body: body.value,
is_tail_recursive,
}, },
); );
} }
@ -174,6 +178,9 @@ impl<'a> Procs<'a> {
ret_var: Variable, ret_var: Variable,
layout_cache: &mut LayoutCache<'a>, layout_cache: &mut LayoutCache<'a>,
) -> Result<Layout<'a>, RuntimeError> { ) -> Result<Layout<'a>, RuntimeError> {
// anonymous functions cannot reference themselves, therefore cannot be tail-recursive
let is_tail_recursive = false;
match patterns_to_when(env, layout_cache, loc_args, ret_var, loc_body) { match patterns_to_when(env, layout_cache, loc_args, ret_var, loc_body) {
Ok((pattern_vars, pattern_symbols, body)) => { Ok((pattern_vars, pattern_symbols, body)) => {
// an anonymous closure. These will always be specialized already // an anonymous closure. These will always be specialized already
@ -212,6 +219,7 @@ impl<'a> Procs<'a> {
annotation, annotation,
pattern_symbols, pattern_symbols,
body: body.value, body: body.value,
is_tail_recursive,
}, },
); );
} }
@ -221,6 +229,7 @@ impl<'a> Procs<'a> {
annotation, annotation,
pattern_symbols, pattern_symbols,
body: body.value, body: body.value,
is_tail_recursive,
}; };
// Mark this proc as in-progress, so if we're dealing with // Mark this proc as in-progress, so if we're dealing with
@ -370,7 +379,7 @@ impl<'a, 'i> Env<'a, 'i> {
} }
#[derive(Clone, Debug, PartialEq, Copy, Eq, Hash)] #[derive(Clone, Debug, PartialEq, Copy, Eq, Hash)]
pub struct JoinPointId(Symbol); pub struct JoinPointId(pub Symbol);
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub struct Param<'a> { pub struct Param<'a> {
@ -991,6 +1000,7 @@ fn specialize<'a>(
annotation, annotation,
pattern_symbols, pattern_symbols,
body, body,
is_tail_recursive,
} = partial_proc; } = partial_proc;
// unify the called function with the specialized signature, then specialize the function body // unify the called function with the specialized signature, then specialize the function body
@ -999,9 +1009,7 @@ fn specialize<'a>(
debug_assert!(matches!(unified, roc_unify::unify::Unified::Success(_))); debug_assert!(matches!(unified, roc_unify::unify::Unified::Success(_)));
let ret_symbol = env.unique_symbol(); let specialized_body = from_can(env, body, procs, layout_cache);
let hole = env.arena.alloc(Stmt::Ret(ret_symbol));
let specialized_body = with_hole(env, body, procs, layout_cache, ret_symbol, hole);
// reset subs, so we don't get type errors when specializing for a different signature // reset subs, so we don't get type errors when specializing for a different signature
env.subs.rollback_to(snapshot); env.subs.rollback_to(snapshot);
@ -1020,6 +1028,11 @@ fn specialize<'a>(
proc_args.push((layout, *arg_name)); proc_args.push((layout, *arg_name));
} }
let proc_args = proc_args.into_bump_slice();
let specialized_body =
crate::tail_recursion::make_tail_recursive(env, proc_name, specialized_body, proc_args);
let ret_layout = layout_cache let ret_layout = layout_cache
.from_var(&env.arena, ret_var, env.subs) .from_var(&env.arena, ret_var, env.subs)
.unwrap_or_else(|err| panic!("TODO handle invalid function {:?}", err)); .unwrap_or_else(|err| panic!("TODO handle invalid function {:?}", err));
@ -1029,10 +1042,11 @@ fn specialize<'a>(
let proc = Proc { let proc = Proc {
name: proc_name, name: proc_name,
args: proc_args.into_bump_slice(), args: proc_args,
body: specialized_body, body: specialized_body,
closes_over: closes_over_layout, closes_over: closes_over_layout,
ret_layout, ret_layout,
is_tail_recursive,
}; };
Ok(proc) Ok(proc)
@ -1088,16 +1102,15 @@ pub fn with_hole<'a>(
}, },
LetNonRec(def, cont, _, _) => { LetNonRec(def, cont, _, _) => {
if let roc_can::pattern::Pattern::Identifier(symbol) = &def.loc_pattern.value { if let roc_can::pattern::Pattern::Identifier(symbol) = &def.loc_pattern.value {
if let Closure(_, _, _, _, _) = &def.loc_expr.value { if let Closure(ann, _, recursivity, loc_args, boxed_body) = def.loc_expr.value {
// Now that we know for sure it's a closure, get an owned
// version of these variant args so we can use them properly.
match def.loc_expr.value {
Closure(ann, _, _, loc_args, boxed_body) => {
// Extract Procs, but discard the resulting Expr::Load. // Extract Procs, but discard the resulting Expr::Load.
// That Load looks up the pointer, which we won't use here! // That Load looks up the pointer, which we won't use here!
let (loc_body, ret_var) = *boxed_body; let (loc_body, ret_var) = *boxed_body;
let is_tail_recursive =
matches!(recursivity, roc_can::expr::Recursive::TailRecursive);
procs.insert_named( procs.insert_named(
env, env,
layout_cache, layout_cache,
@ -1105,14 +1118,12 @@ pub fn with_hole<'a>(
ann, ann,
loc_args, loc_args,
loc_body, loc_body,
is_tail_recursive,
ret_var, ret_var,
); );
return with_hole(env, cont.value, procs, layout_cache, assigned, hole); return with_hole(env, cont.value, procs, layout_cache, assigned, hole);
} }
_ => unreachable!(),
}
}
} }
if let roc_can::pattern::Pattern::Identifier(symbol) = def.loc_pattern.value { if let roc_can::pattern::Pattern::Identifier(symbol) = def.loc_pattern.value {
@ -1175,15 +1186,15 @@ pub fn with_hole<'a>(
// because Roc is strict, only functions can be recursive! // because Roc is strict, only functions can be recursive!
for def in defs.into_iter() { for def in defs.into_iter() {
if let roc_can::pattern::Pattern::Identifier(symbol) = &def.loc_pattern.value { if let roc_can::pattern::Pattern::Identifier(symbol) = &def.loc_pattern.value {
// Now that we know for sure it's a closure, get an owned if let Closure(ann, _, recursivity, loc_args, boxed_body) = def.loc_expr.value {
// version of these variant args so we can use them properly.
match def.loc_expr.value {
Closure(ann, _, _, loc_args, boxed_body) => {
// Extract Procs, but discard the resulting Expr::Load. // Extract Procs, but discard the resulting Expr::Load.
// That Load looks up the pointer, which we won't use here! // That Load looks up the pointer, which we won't use here!
let (loc_body, ret_var) = *boxed_body; let (loc_body, ret_var) = *boxed_body;
let is_tail_recursive =
matches!(recursivity, roc_can::expr::Recursive::TailRecursive);
procs.insert_named( procs.insert_named(
env, env,
layout_cache, layout_cache,
@ -1191,13 +1202,12 @@ pub fn with_hole<'a>(
ann, ann,
loc_args, loc_args,
loc_body, loc_body,
is_tail_recursive,
ret_var, ret_var,
); );
continue; continue;
} }
_ => unreachable!("recursive value is not a function"),
}
} }
unreachable!("recursive value does not have Identifier pattern") unreachable!("recursive value does not have Identifier pattern")
} }
@ -1456,9 +1466,59 @@ pub fn with_hole<'a>(
.from_var(env.arena, cond_var, env.subs) .from_var(env.arena, cond_var, env.subs)
.expect("invalid cond_layout"); .expect("invalid cond_layout");
// if the hole is a return, then we don't need to merge the two
// branches together again, we can just immediately return
let is_terminated = matches!(hole, Stmt::Ret(_));
if is_terminated {
let terminator = hole;
let mut stmt = with_hole(
env,
final_else.value,
procs,
layout_cache,
assigned,
terminator,
);
for (loc_cond, loc_then) in branches.into_iter().rev() {
let branching_symbol = env.unique_symbol();
let then = with_hole(
env,
loc_then.value,
procs,
layout_cache,
assigned,
terminator,
);
stmt = Stmt::Cond {
cond_symbol: branching_symbol,
branching_symbol,
cond_layout: cond_layout.clone(),
branching_layout: cond_layout.clone(),
pass: env.arena.alloc(then),
fail: env.arena.alloc(stmt),
ret_layout: ret_layout.clone(),
};
// add condition
stmt = with_hole(
env,
loc_cond.value,
procs,
layout_cache,
branching_symbol,
env.arena.alloc(stmt),
);
}
stmt
} else {
let assigned_in_jump = env.unique_symbol(); let assigned_in_jump = env.unique_symbol();
let id = JoinPointId(env.unique_symbol()); let id = JoinPointId(env.unique_symbol());
let jump = env
let terminator = env
.arena .arena
.alloc(Stmt::Jump(id, env.arena.alloc([assigned_in_jump]))); .alloc(Stmt::Jump(id, env.arena.alloc([assigned_in_jump])));
@ -1468,7 +1528,7 @@ pub fn with_hole<'a>(
procs, procs,
layout_cache, layout_cache,
assigned_in_jump, assigned_in_jump,
jump, terminator,
); );
for (loc_cond, loc_then) in branches.into_iter().rev() { for (loc_cond, loc_then) in branches.into_iter().rev() {
@ -1479,7 +1539,7 @@ pub fn with_hole<'a>(
procs, procs,
layout_cache, layout_cache,
assigned_in_jump, assigned_in_jump,
jump, terminator,
); );
stmt = Stmt::Cond { stmt = Stmt::Cond {
@ -1520,6 +1580,7 @@ pub fn with_hole<'a>(
continuation: hole, continuation: hole,
} }
} }
}
When { When {
cond_var, cond_var,
@ -1901,6 +1962,89 @@ pub fn from_can<'a>(
use roc_can::expr::Expr::*; use roc_can::expr::Expr::*;
match can_expr { match can_expr {
When {
cond_var,
expr_var,
region,
loc_cond,
branches,
} => {
let cond_symbol = if let roc_can::expr::Expr::Var(symbol) = loc_cond.value {
symbol
} else {
env.unique_symbol()
};
let mut stmt = from_can_when(
env,
cond_var,
expr_var,
region,
cond_symbol,
branches,
layout_cache,
procs,
None,
);
// define the `when` condition
if let roc_can::expr::Expr::Var(_) = loc_cond.value {
// do nothing
} else {
stmt = with_hole(
env,
loc_cond.value,
procs,
layout_cache,
cond_symbol,
env.arena.alloc(stmt),
);
};
stmt
}
If {
cond_var,
branch_var,
branches,
final_else,
} => {
let ret_layout = layout_cache
.from_var(env.arena, branch_var, env.subs)
.expect("invalid ret_layout");
let cond_layout = layout_cache
.from_var(env.arena, cond_var, env.subs)
.expect("invalid cond_layout");
let mut stmt = from_can(env, final_else.value, procs, layout_cache);
for (loc_cond, loc_then) in branches.into_iter().rev() {
let branching_symbol = env.unique_symbol();
let then = from_can(env, loc_then.value, procs, layout_cache);
stmt = Stmt::Cond {
cond_symbol: branching_symbol,
branching_symbol,
cond_layout: cond_layout.clone(),
branching_layout: cond_layout.clone(),
pass: env.arena.alloc(then),
fail: env.arena.alloc(stmt),
ret_layout: ret_layout.clone(),
};
// add condition
stmt = with_hole(
env,
loc_cond.value,
procs,
layout_cache,
branching_symbol,
env.arena.alloc(stmt),
);
}
stmt
}
LetRec(defs, cont, _, _) => { LetRec(defs, cont, _, _) => {
// because Roc is strict, only functions can be recursive! // because Roc is strict, only functions can be recursive!
for def in defs.into_iter() { for def in defs.into_iter() {
@ -1908,12 +2052,15 @@ pub fn from_can<'a>(
// Now that we know for sure it's a closure, get an owned // Now that we know for sure it's a closure, get an owned
// version of these variant args so we can use them properly. // version of these variant args so we can use them properly.
match def.loc_expr.value { match def.loc_expr.value {
Closure(ann, _, _, loc_args, boxed_body) => { Closure(ann, _, recursivity, loc_args, boxed_body) => {
// Extract Procs, but discard the resulting Expr::Load. // Extract Procs, but discard the resulting Expr::Load.
// That Load looks up the pointer, which we won't use here! // That Load looks up the pointer, which we won't use here!
let (loc_body, ret_var) = *boxed_body; let (loc_body, ret_var) = *boxed_body;
let is_tail_recursive =
matches!(recursivity, roc_can::expr::Recursive::TailRecursive);
procs.insert_named( procs.insert_named(
env, env,
layout_cache, layout_cache,
@ -1921,6 +2068,7 @@ pub fn from_can<'a>(
ann, ann,
loc_args, loc_args,
loc_body, loc_body,
is_tail_recursive,
ret_var, ret_var,
); );
@ -1940,12 +2088,15 @@ pub fn from_can<'a>(
// Now that we know for sure it's a closure, get an owned // Now that we know for sure it's a closure, get an owned
// version of these variant args so we can use them properly. // version of these variant args so we can use them properly.
match def.loc_expr.value { match def.loc_expr.value {
Closure(ann, _, _, loc_args, boxed_body) => { Closure(ann, _, recursivity, loc_args, boxed_body) => {
// Extract Procs, but discard the resulting Expr::Load. // Extract Procs, but discard the resulting Expr::Load.
// That Load looks up the pointer, which we won't use here! // That Load looks up the pointer, which we won't use here!
let (loc_body, ret_var) = *boxed_body; let (loc_body, ret_var) = *boxed_body;
let is_tail_recursive =
matches!(recursivity, roc_can::expr::Recursive::TailRecursive);
procs.insert_named( procs.insert_named(
env, env,
layout_cache, layout_cache,
@ -1953,6 +2104,7 @@ pub fn from_can<'a>(
ann, ann,
loc_args, loc_args,
loc_body, loc_body,
is_tail_recursive,
ret_var, ret_var,
); );
@ -2648,11 +2800,11 @@ fn store_pattern<'a>(
} }
Shadowed(_region, _ident) => { Shadowed(_region, _ident) => {
return Err(&"TODO"); return Err(&"shadowed");
} }
UnsupportedPattern(_region) => { UnsupportedPattern(_region) => {
return Err(&"TODO"); return Err(&"unsupported pattern");
} }
} }

View file

@ -14,6 +14,7 @@
pub mod inc_dec; pub mod inc_dec;
pub mod ir; pub mod ir;
pub mod layout; pub mod layout;
pub mod tail_recursion;
// Temporary, while we can build up test cases and optimize the exhaustiveness checking. // Temporary, while we can build up test cases and optimize the exhaustiveness checking.
// For now, following this warning's advice will lead to nasty type inference errors. // For now, following this warning's advice will lead to nasty type inference errors.

View file

@ -0,0 +1,201 @@
use crate::ir::{CallType, Env, Expr, JoinPointId, Param, Stmt};
use crate::layout::Layout;
use bumpalo::collections::Vec;
use bumpalo::Bump;
use roc_module::symbol::Symbol;
pub fn make_tail_recursive<'a>(
env: &mut Env<'a, '_>,
needle: Symbol,
stmt: Stmt<'a>,
args: &'a [(Layout<'a>, Symbol)],
) -> Stmt<'a> {
let id = JoinPointId(env.unique_symbol());
let alloced = env.arena.alloc(stmt);
match insert_jumps(env.arena, alloced, id, needle) {
None => alloced.clone(),
Some(new) => {
// jumps were inserted, we must now add a join point
let params = Vec::from_iter_in(
args.iter().map(|(layout, symbol)| Param {
symbol: *symbol,
layout: layout.clone(),
borrow: true,
}),
env.arena,
)
.into_bump_slice();
let args = Vec::from_iter_in(args.iter().map(|t| t.1), env.arena).into_bump_slice();
let jump = env.arena.alloc(Stmt::Jump(id, args));
Stmt::Join {
id,
remainder: jump,
parameters: params,
continuation: new,
}
}
}
}
fn insert_jumps<'a>(
arena: &'a Bump,
stmt: &'a Stmt<'a>,
goal_id: JoinPointId,
needle: Symbol,
) -> Option<&'a Stmt<'a>> {
use Stmt::*;
match stmt {
Let(
symbol,
Expr::FunctionCall {
call_type: CallType::ByName(fsym),
args,
..
},
_,
Stmt::Ret(rsym),
) if needle == *fsym && symbol == rsym => {
// replace the call and return with a jump
let jump = Stmt::Jump(goal_id, args);
Some(arena.alloc(jump))
}
Let(symbol, expr, layout, cont) => {
let opt_cont = insert_jumps(arena, cont, goal_id, needle);
if opt_cont.is_some() {
let cont = opt_cont.unwrap_or(cont);
Some(arena.alloc(Let(*symbol, expr.clone(), layout.clone(), cont)))
} else {
None
}
}
Join {
id,
parameters,
remainder,
continuation,
} => {
let opt_remainder = insert_jumps(arena, remainder, goal_id, needle);
let opt_continuation = insert_jumps(arena, continuation, goal_id, needle);
if opt_remainder.is_some() || opt_continuation.is_some() {
let remainder = opt_remainder.unwrap_or(remainder);
let continuation = opt_continuation.unwrap_or_else(|| *continuation);
Some(arena.alloc(Join {
id: *id,
parameters,
remainder,
continuation,
}))
} else {
None
}
}
Cond {
cond_symbol,
cond_layout,
branching_symbol,
branching_layout,
pass,
fail,
ret_layout,
} => {
let opt_pass = insert_jumps(arena, pass, goal_id, needle);
let opt_fail = insert_jumps(arena, fail, goal_id, needle);
if opt_pass.is_some() || opt_fail.is_some() {
let pass = opt_pass.unwrap_or(pass);
let fail = opt_fail.unwrap_or_else(|| *fail);
Some(arena.alloc(Cond {
cond_symbol: *cond_symbol,
cond_layout: cond_layout.clone(),
branching_symbol: *branching_symbol,
branching_layout: branching_layout.clone(),
pass,
fail,
ret_layout: ret_layout.clone(),
}))
} else {
None
}
}
Switch {
cond_symbol,
cond_layout,
branches,
default_branch,
ret_layout,
} => {
let opt_default = insert_jumps(arena, default_branch, goal_id, needle);
let mut did_change = false;
let opt_branches = Vec::from_iter_in(
branches.iter().map(|(label, branch)| {
match insert_jumps(arena, branch, goal_id, needle) {
None => None,
Some(branch) => {
did_change = true;
Some((*label, branch.clone()))
}
}
}),
arena,
);
if opt_default.is_some() || did_change {
let default_branch = opt_default.unwrap_or(default_branch);
let branches = if did_change {
let new = Vec::from_iter_in(
opt_branches.into_iter().zip(branches.iter()).map(
|(opt_branch, branch)| match opt_branch {
None => branch.clone(),
Some(new_branch) => new_branch,
},
),
arena,
);
new.into_bump_slice()
} else {
branches
};
Some(arena.alloc(Switch {
cond_symbol: *cond_symbol,
cond_layout: cond_layout.clone(),
default_branch,
branches,
ret_layout: ret_layout.clone(),
}))
} else {
None
}
}
Ret(_) => None,
Inc(symbol, cont) => match insert_jumps(arena, cont, goal_id, needle) {
Some(cont) => Some(arena.alloc(Inc(*symbol, cont))),
None => None,
},
Dec(symbol, cont) => match insert_jumps(arena, cont, goal_id, needle) {
Some(cont) => Some(arena.alloc(Dec(*symbol, cont))),
None => None,
},
Jump(_, _) => None,
RuntimeError(_) => None,
}
}

View file

@ -130,59 +130,6 @@ mod test_mono {
} }
#[test] #[test]
fn ir_if() {
compiles_to_ir(
r#"
if True then 1 else 2
"#,
indoc!(
r#"
let Test.3 = true;
if Test.3 then
let Test.1 = 1i64;
jump Test.2 Test.1;
else
let Test.1 = 2i64;
jump Test.2 Test.1;
joinpoint Test.2 Test.0:
ret Test.0;
"#
),
)
}
#[test]
fn ir_when_enum() {
compiles_to_ir(
r#"
when Blue is
Red -> 1
White -> 2
Blue -> 3
"#,
indoc!(
r#"
let Test.1 = 0u8;
switch Test.1:
case 1:
let Test.3 = 1i64;
jump Test.2 Test.3;
case 2:
let Test.4 = 2i64;
jump Test.2 Test.4;
default:
let Test.5 = 3i64;
jump Test.2 Test.5;
joinpoint Test.2 Test.0:
ret Test.0;
"#
),
)
}
#[test] #[test]
fn ir_when_maybe() { fn ir_when_maybe() {
compiles_to_ir( compiles_to_ir(
@ -193,22 +140,20 @@ mod test_mono {
"#, "#,
indoc!( indoc!(
r#" r#"
let Test.11 = 0i64; let Test.9 = 0i64;
let Test.12 = 3i64; let Test.10 = 3i64;
let Test.2 = Just Test.11 Test.12; let Test.1 = Just Test.9 Test.10;
let Test.7 = true; let Test.5 = true;
let Test.9 = Index 0 Test.2; let Test.7 = Index 0 Test.1;
let Test.8 = 0i64; let Test.6 = 0i64;
let Test.10 = lowlevel Eq Test.8 Test.9; let Test.8 = lowlevel Eq Test.6 Test.7;
let Test.6 = lowlevel And Test.10 Test.7; let Test.4 = lowlevel And Test.8 Test.5;
if Test.6 then if Test.4 then
let Test.0 = Index 1 Test.2; let Test.0 = Index 1 Test.1;
jump Test.3 Test.0; ret Test.0;
else else
let Test.5 = 0i64; let Test.3 = 0i64;
jump Test.3 Test.5; ret Test.3;
joinpoint Test.3 Test.1:
ret Test.1;
"# "#
), ),
) )
@ -225,25 +170,23 @@ mod test_mono {
"#, "#,
indoc!( indoc!(
r#" r#"
let Test.9 = 1i64; let Test.7 = 1i64;
let Test.10 = 1i64; let Test.8 = 1i64;
let Test.11 = 2i64; let Test.9 = 2i64;
let Test.4 = These Test.9 Test.10 Test.11; let Test.3 = These Test.7 Test.8 Test.9;
switch Test.4: switch Test.3:
case 2: case 2:
let Test.0 = Index 1 Test.4; let Test.0 = Index 1 Test.3;
jump Test.5 Test.0; ret Test.0;
case 0: case 0:
let Test.1 = Index 1 Test.4; let Test.1 = Index 1 Test.3;
jump Test.5 Test.1; ret Test.1;
default: default:
let Test.2 = Index 1 Test.4; let Test.2 = Index 1 Test.3;
jump Test.5 Test.2; ret Test.2;
joinpoint Test.5 Test.3:
ret Test.3;
"# "#
), ),
) )
@ -258,13 +201,11 @@ mod test_mono {
"#, "#,
indoc!( indoc!(
r#" r#"
let Test.6 = 1i64; let Test.4 = 1i64;
let Test.7 = 3.14f64; let Test.5 = 3.14f64;
let Test.2 = Struct {Test.6, Test.7}; let Test.1 = Struct {Test.4, Test.5};
let Test.0 = Index 0 Test.2; let Test.0 = Index 0 Test.1;
jump Test.3 Test.0; ret Test.0;
joinpoint Test.3 Test.1:
ret Test.1;
"# "#
), ),
) )
@ -322,37 +263,33 @@ mod test_mono {
indoc!( indoc!(
r#" r#"
procedure Num.32 (#Attr.2, #Attr.3): procedure Num.32 (#Attr.2, #Attr.3):
let Test.21 = 0i64; let Test.18 = 0i64;
let Test.18 = lowlevel NotEq #Attr.3 Test.21; let Test.14 = lowlevel NotEq #Attr.3 Test.18;
if Test.18 then if Test.14 then
let Test.19 = 1i64; let Test.16 = 1i64;
let Test.20 = lowlevel NumDivUnchecked #Attr.2 #Attr.3; let Test.17 = lowlevel NumDivUnchecked #Attr.2 #Attr.3;
let Test.14 = Ok Test.19 Test.20; let Test.15 = Ok Test.16 Test.17;
jump Test.15 Test.14; ret Test.15;
else else
let Test.16 = 0i64; let Test.12 = 0i64;
let Test.17 = Struct {}; let Test.13 = Struct {};
let Test.14 = Err Test.16 Test.17; let Test.11 = Err Test.12 Test.13;
jump Test.15 Test.14; ret Test.11;
joinpoint Test.15 Test.13:
ret Test.13;
let Test.11 = 1000i64; let Test.9 = 1000i64;
let Test.12 = 10i64; let Test.10 = 10i64;
let Test.2 = CallByName Num.32 Test.11 Test.12; let Test.1 = CallByName Num.32 Test.9 Test.10;
let Test.7 = true; let Test.5 = true;
let Test.9 = Index 0 Test.2; let Test.7 = Index 0 Test.1;
let Test.8 = 1i64; let Test.6 = 1i64;
let Test.10 = lowlevel Eq Test.8 Test.9; let Test.8 = lowlevel Eq Test.6 Test.7;
let Test.6 = lowlevel And Test.10 Test.7; let Test.4 = lowlevel And Test.8 Test.5;
if Test.6 then if Test.4 then
let Test.0 = Index 1 Test.2; let Test.0 = Index 1 Test.1;
jump Test.3 Test.0; ret Test.0;
else else
let Test.5 = -1i64; let Test.3 = -1i64;
jump Test.3 Test.5; ret Test.3;
joinpoint Test.3 Test.1:
ret Test.1;
"# "#
), ),
) )
@ -396,27 +333,25 @@ mod test_mono {
indoc!( indoc!(
r#" r#"
procedure Num.14 (#Attr.2, #Attr.3): procedure Num.14 (#Attr.2, #Attr.3):
let Test.14 = lowlevel NumAdd #Attr.2 #Attr.3; let Test.12 = lowlevel NumAdd #Attr.2 #Attr.3;
ret Test.14; ret Test.12;
let Test.12 = 0i64; let Test.10 = 0i64;
let Test.13 = 41i64; let Test.11 = 41i64;
let Test.0 = Just Test.12 Test.13; let Test.0 = Just Test.10 Test.11;
let Test.8 = true; let Test.6 = true;
let Test.10 = Index 0 Test.0; let Test.8 = Index 0 Test.0;
let Test.9 = 0i64; let Test.7 = 0i64;
let Test.11 = lowlevel Eq Test.9 Test.10; let Test.9 = lowlevel Eq Test.7 Test.8;
let Test.7 = lowlevel And Test.11 Test.8; let Test.5 = lowlevel And Test.9 Test.6;
if Test.7 then if Test.5 then
let Test.1 = Index 1 Test.0; let Test.1 = Index 1 Test.0;
let Test.5 = 1i64; let Test.3 = 1i64;
let Test.4 = CallByName Num.14 Test.1 Test.5; let Test.2 = CallByName Num.14 Test.1 Test.3;
jump Test.3 Test.4;
else
let Test.6 = 1i64;
jump Test.3 Test.6;
joinpoint Test.3 Test.2:
ret Test.2; ret Test.2;
else
let Test.4 = 1i64;
ret Test.4;
"# "#
), ),
) )
@ -441,31 +376,6 @@ mod test_mono {
) )
} }
#[test]
fn join_points() {
compiles_to_ir(
r#"
x =
if True then 1 else 2
x
"#,
indoc!(
r#"
let Test.4 = true;
if Test.4 then
let Test.2 = 1i64;
jump Test.3 Test.2;
else
let Test.2 = 2i64;
jump Test.3 Test.2;
joinpoint Test.3 Test.0:
ret Test.0;
"#
),
)
}
#[test] #[test]
fn guard_pattern_true() { fn guard_pattern_true() {
compiles_to_ir( compiles_to_ir(
@ -476,23 +386,21 @@ mod test_mono {
"#, "#,
indoc!( indoc!(
r#" r#"
let Test.1 = 2i64; let Test.0 = 2i64;
let Test.8 = true; let Test.6 = true;
let Test.9 = 2i64; let Test.7 = 2i64;
let Test.12 = lowlevel Eq Test.9 Test.1; let Test.10 = lowlevel Eq Test.7 Test.0;
let Test.10 = lowlevel And Test.12 Test.8; let Test.8 = lowlevel And Test.10 Test.6;
let Test.5 = false; let Test.3 = false;
jump Test.4 Test.5;
joinpoint Test.4 Test.11:
let Test.7 = lowlevel And Test.11 Test.10;
if Test.7 then
let Test.3 = 42i64;
jump Test.2 Test.3; jump Test.2 Test.3;
joinpoint Test.2 Test.9:
let Test.5 = lowlevel And Test.9 Test.8;
if Test.5 then
let Test.1 = 42i64;
ret Test.1;
else else
let Test.6 = 0i64; let Test.4 = 0i64;
jump Test.2 Test.6; ret Test.4;
joinpoint Test.2 Test.0:
ret Test.0;
"# "#
), ),
) )
@ -508,17 +416,15 @@ mod test_mono {
indoc!( indoc!(
r#" r#"
procedure Num.14 (#Attr.2, #Attr.3): procedure Num.14 (#Attr.2, #Attr.3):
let Test.7 = lowlevel NumAdd #Attr.2 #Attr.3; let Test.5 = lowlevel NumAdd #Attr.2 #Attr.3;
ret Test.7; ret Test.5;
let Test.6 = 2i64; let Test.4 = 2i64;
let Test.2 = Struct {Test.6}; let Test.1 = Struct {Test.4};
let Test.0 = Index 0 Test.2; let Test.0 = Index 0 Test.1;
let Test.5 = 3i64; let Test.3 = 3i64;
let Test.4 = CallByName Num.14 Test.0 Test.5; let Test.2 = CallByName Num.14 Test.0 Test.3;
jump Test.3 Test.4; ret Test.2;
joinpoint Test.3 Test.1:
ret Test.1;
"# "#
), ),
) )
@ -540,34 +446,32 @@ mod test_mono {
indoc!( indoc!(
r#" r#"
procedure Num.14 (#Attr.2, #Attr.3): procedure Num.14 (#Attr.2, #Attr.3):
let Test.22 = lowlevel NumAdd #Attr.2 #Attr.3; let Test.20 = lowlevel NumAdd #Attr.2 #Attr.3;
ret Test.22; ret Test.20;
let Test.16 = 0i64;
let Test.18 = 0i64; let Test.18 = 0i64;
let Test.20 = 0i64; let Test.19 = 41i64;
let Test.21 = 41i64; let Test.17 = Just Test.18 Test.19;
let Test.19 = Just Test.20 Test.21; let Test.1 = Just Test.16 Test.17;
let Test.1 = Just Test.18 Test.19; let Test.8 = true;
let Test.10 = true; let Test.10 = Index 0 Test.1;
let Test.9 = 0i64;
let Test.15 = lowlevel Eq Test.9 Test.10;
let Test.13 = lowlevel And Test.15 Test.8;
let Test.12 = Index 0 Test.1; let Test.12 = Index 0 Test.1;
let Test.11 = 0i64; let Test.11 = 0i64;
let Test.17 = lowlevel Eq Test.11 Test.12; let Test.14 = lowlevel Eq Test.11 Test.12;
let Test.15 = lowlevel And Test.17 Test.10; let Test.7 = lowlevel And Test.14 Test.13;
let Test.14 = Index 0 Test.1; if Test.7 then
let Test.13 = 0i64; let Test.5 = Index 1 Test.1;
let Test.16 = lowlevel Eq Test.13 Test.14; let Test.2 = Index 1 Test.5;
let Test.9 = lowlevel And Test.16 Test.15; let Test.4 = 1i64;
if Test.9 then let Test.3 = CallByName Num.14 Test.2 Test.4;
let Test.7 = Index 1 Test.1;
let Test.2 = Index 1 Test.7;
let Test.6 = 1i64;
let Test.5 = CallByName Num.14 Test.2 Test.6;
jump Test.4 Test.5;
else
let Test.8 = 1i64;
jump Test.4 Test.8;
joinpoint Test.4 Test.3:
ret Test.3; ret Test.3;
else
let Test.6 = 1i64;
ret Test.6;
"# "#
), ),
) )
@ -584,31 +488,29 @@ mod test_mono {
indoc!( indoc!(
r#" r#"
procedure Num.14 (#Attr.2, #Attr.3): procedure Num.14 (#Attr.2, #Attr.3):
let Test.18 = lowlevel NumAdd #Attr.2 #Attr.3; let Test.16 = lowlevel NumAdd #Attr.2 #Attr.3;
ret Test.18; ret Test.16;
let Test.16 = 2i64; let Test.14 = 2i64;
let Test.17 = 3i64; let Test.15 = 3i64;
let Test.3 = Struct {Test.16, Test.17}; let Test.2 = Struct {Test.14, Test.15};
let Test.8 = true; let Test.6 = true;
let Test.9 = 4i64; let Test.7 = 4i64;
let Test.10 = Index 0 Test.3; let Test.8 = Index 0 Test.2;
let Test.15 = lowlevel Eq Test.9 Test.10; let Test.13 = lowlevel Eq Test.7 Test.8;
let Test.13 = lowlevel And Test.15 Test.8; let Test.11 = lowlevel And Test.13 Test.6;
let Test.11 = 3i64; let Test.9 = 3i64;
let Test.12 = Index 1 Test.3; let Test.10 = Index 1 Test.2;
let Test.14 = lowlevel Eq Test.11 Test.12; let Test.12 = lowlevel Eq Test.9 Test.10;
let Test.7 = lowlevel And Test.14 Test.13; let Test.5 = lowlevel And Test.12 Test.11;
if Test.7 then if Test.5 then
let Test.5 = 9i64; let Test.3 = 9i64;
jump Test.4 Test.5; ret Test.3;
else else
let Test.0 = Index 0 Test.3; let Test.0 = Index 0 Test.2;
let Test.1 = Index 1 Test.3; let Test.1 = Index 1 Test.2;
let Test.6 = CallByName Num.14 Test.0 Test.1; let Test.4 = CallByName Num.14 Test.0 Test.1;
jump Test.4 Test.6; ret Test.4;
joinpoint Test.4 Test.2:
ret Test.2;
"# "#
), ),
) )
@ -683,8 +585,8 @@ mod test_mono {
ret Test.9; ret Test.9;
procedure Num.14 (#Attr.2, #Attr.3): procedure Num.14 (#Attr.2, #Attr.3):
let Test.10 = lowlevel NumAdd #Attr.2 #Attr.3; let Test.11 = lowlevel NumAdd #Attr.2 #Attr.3;
ret Test.10; ret Test.11;
let Test.8 = 1f64; let Test.8 = 1f64;
let Test.1 = Array [Test.8]; let Test.1 = Array [Test.8];
@ -752,14 +654,12 @@ mod test_mono {
"#, "#,
indoc!( indoc!(
r#" r#"
let Test.3 = true; let Test.1 = true;
if Test.3 then if Test.1 then
let Test.1 = 1i64; let Test.2 = 1i64;
jump Test.2 Test.1; ret Test.2;
else else
let Test.1 = 2i64; let Test.0 = 2i64;
jump Test.2 Test.1;
joinpoint Test.2 Test.0:
ret Test.0; ret Test.0;
"# "#
), ),
@ -779,21 +679,17 @@ mod test_mono {
"#, "#,
indoc!( indoc!(
r#" r#"
let Test.6 = true; let Test.3 = true;
if Test.6 then if Test.3 then
let Test.1 = 1i64; let Test.4 = 1i64;
jump Test.2 Test.1; ret Test.4;
else else
let Test.5 = false; let Test.1 = false;
if Test.5 then if Test.1 then
let Test.3 = 2i64; let Test.2 = 2i64;
jump Test.4 Test.3; ret Test.2;
else else
let Test.3 = 3i64; let Test.0 = 3i64;
jump Test.4 Test.3;
joinpoint Test.4 Test.1:
jump Test.2 Test.1;
joinpoint Test.2 Test.0:
ret Test.0; ret Test.0;
"# "#
), ),
@ -904,24 +800,22 @@ mod test_mono {
indoc!( indoc!(
r#" r#"
procedure Bool.5 (#Attr.2, #Attr.3): procedure Bool.5 (#Attr.2, #Attr.3):
let Test.12 = lowlevel Eq #Attr.2 #Attr.3; let Test.10 = lowlevel Eq #Attr.2 #Attr.3;
ret Test.12; ret Test.10;
let Test.2 = 10i64; let Test.1 = 10i64;
let Test.10 = true; let Test.8 = true;
let Test.7 = 5i64; let Test.5 = 5i64;
let Test.6 = CallByName Bool.5 Test.2 Test.7; let Test.4 = CallByName Bool.5 Test.1 Test.5;
jump Test.5 Test.6;
joinpoint Test.5 Test.11:
let Test.9 = lowlevel And Test.11 Test.10;
if Test.9 then
let Test.4 = 0i64;
jump Test.3 Test.4; jump Test.3 Test.4;
joinpoint Test.3 Test.9:
let Test.7 = lowlevel And Test.9 Test.8;
if Test.7 then
let Test.2 = 0i64;
ret Test.2;
else else
let Test.8 = 42i64; let Test.6 = 42i64;
jump Test.3 Test.8; ret Test.6;
joinpoint Test.3 Test.1:
ret Test.1;
"# "#
), ),
) )
@ -976,17 +870,15 @@ mod test_mono {
), ),
indoc!( indoc!(
r#" r#"
let Test.2 = 0i64; let Test.1 = 0i64;
let Test.7 = true; let Test.5 = true;
let Test.8 = 1i64; let Test.6 = 1i64;
let Test.9 = lowlevel Eq Test.8 Test.2; let Test.7 = lowlevel Eq Test.6 Test.1;
let Test.6 = lowlevel And Test.9 Test.7; let Test.4 = lowlevel And Test.7 Test.5;
if Test.6 then if Test.4 then
let Test.4 = 12i64; let Test.2 = 12i64;
jump Test.3 Test.4; ret Test.2;
else else
jump Test.3 Test.2;
joinpoint Test.3 Test.1:
ret Test.1; ret Test.1;
"# "#
), ),
@ -1010,15 +902,13 @@ mod test_mono {
indoc!( indoc!(
r#" r#"
procedure List.4 (#Attr.2, #Attr.3, #Attr.4): procedure List.4 (#Attr.2, #Attr.3, #Attr.4):
let Test.15 = lowlevel ListLen #Attr.2; let Test.14 = lowlevel ListLen #Attr.2;
let Test.14 = lowlevel NumLt #Attr.3 Test.15; let Test.12 = lowlevel NumLt #Attr.3 Test.14;
if Test.14 then if Test.12 then
let Test.12 = lowlevel ListSet #Attr.2 #Attr.3 #Attr.4; let Test.13 = lowlevel ListSet #Attr.2 #Attr.3 #Attr.4;
jump Test.13 Test.12; ret Test.13;
else else
jump Test.13 #Attr.2; ret #Attr.2;
joinpoint Test.13 Test.11:
ret Test.11;
procedure Test.1 (Test.3): procedure Test.1 (Test.3):
let Test.9 = 0i64; let Test.9 = 0i64;
@ -1179,4 +1069,187 @@ mod test_mono {
), ),
) )
} }
#[allow(dead_code)]
fn quicksort_help() {
crate::helpers::with_larger_debug_stack(|| {
compiles_to_ir(
indoc!(
r#"
quicksortHelp : List (Num a), Int, Int -> List (Num a)
quicksortHelp = \list, low, high ->
if low < high then
(Pair partitionIndex partitioned) = Pair 0 []
partitioned
|> quicksortHelp low (partitionIndex - 1)
|> quicksortHelp (partitionIndex + 1) high
else
list
quicksortHelp [] 0 0
"#
),
indoc!(
r#"
"#
),
)
})
}
#[allow(dead_code)]
fn quicksort_partition_help() {
crate::helpers::with_larger_debug_stack(|| {
compiles_to_ir(
indoc!(
r#"
partitionHelp : Int, Int, List (Num a), Int, (Num a) -> [ Pair Int (List (Num a)) ]
partitionHelp = \i, j, list, high, pivot ->
if j < high then
when List.get list j is
Ok value ->
if value <= pivot then
partitionHelp (i + 1) (j + 1) (swap (i + 1) j list) high pivot
else
partitionHelp i (j + 1) list high pivot
Err _ ->
Pair i list
else
Pair i list
partitionHelp 0 0 [] 0 0
"#
),
indoc!(
r#"
"#
),
)
})
}
#[allow(dead_code)]
fn quicksort_full() {
crate::helpers::with_larger_debug_stack(|| {
compiles_to_ir(
indoc!(
r#"
quicksort = \originalList ->
quicksortHelp : List (Num a), Int, Int -> List (Num a)
quicksortHelp = \list, low, high ->
if low < high then
(Pair partitionIndex partitioned) = partition low high list
partitioned
|> quicksortHelp low (partitionIndex - 1)
|> quicksortHelp (partitionIndex + 1) high
else
list
swap : Int, Int, List a -> List a
swap = \i, j, list ->
when Pair (List.get list i) (List.get list j) is
Pair (Ok atI) (Ok atJ) ->
list
|> List.set i atJ
|> List.set j atI
_ ->
[]
partition : Int, Int, List (Num a) -> [ Pair Int (List (Num a)) ]
partition = \low, high, initialList ->
when List.get initialList high is
Ok pivot ->
when partitionHelp (low - 1) low initialList high pivot is
Pair newI newList ->
Pair (newI + 1) (swap (newI + 1) high newList)
Err _ ->
Pair (low - 1) initialList
partitionHelp : Int, Int, List (Num a), Int, (Num a) -> [ Pair Int (List (Num a)) ]
partitionHelp = \i, j, list, high, pivot ->
if j < high then
when List.get list j is
Ok value ->
if value <= pivot then
partitionHelp (i + 1) (j + 1) (swap (i + 1) j list) high pivot
else
partitionHelp i (j + 1) list high pivot
Err _ ->
Pair i list
else
Pair i list
n = List.len originalList
quicksortHelp originalList 0 (n - 1)
quicksort [1,2,3]
"#
),
indoc!(
r#"
"#
),
)
})
}
#[test]
fn factorial() {
compiles_to_ir(
r#"
factorial = \n, accum ->
when n is
0 ->
accum
_ ->
factorial (n - 1) (n * accum)
factorial 10 1
"#,
indoc!(
r#"
procedure Test.0 (Test.2, Test.3):
jump Test.20 Test.2 Test.3;
joinpoint Test.20 Test.2 Test.3:
let Test.17 = true;
let Test.18 = 0i64;
let Test.19 = lowlevel Eq Test.18 Test.2;
let Test.16 = lowlevel And Test.19 Test.17;
if Test.16 then
ret Test.3;
else
let Test.13 = 1i64;
let Test.9 = CallByName Num.15 Test.2 Test.13;
let Test.10 = CallByName Num.16 Test.2 Test.3;
jump Test.20 Test.9 Test.10;
procedure Num.15 (#Attr.2, #Attr.3):
let Test.14 = lowlevel NumSub #Attr.2 #Attr.3;
ret Test.14;
procedure Num.16 (#Attr.2, #Attr.3):
let Test.11 = lowlevel NumMul #Attr.2 #Attr.3;
ret Test.11;
let Test.5 = 10i64;
let Test.6 = 1i64;
let Test.4 = CallByName Test.0 Test.5 Test.6;
ret Test.4;
"#
),
)
}
} }

View file

@ -7,7 +7,7 @@ extern "C" {
fn quicksort(list: Box<[i64]>) -> Box<[i64]>; fn quicksort(list: Box<[i64]>) -> Box<[i64]>;
} }
const NUM_NUMS: usize = 1_000_00; const NUM_NUMS: usize = 1_000_000;
pub fn main() { pub fn main() {
let nums: Box<[i64]> = { let nums: Box<[i64]> = {
@ -18,10 +18,8 @@ pub fn main() {
nums.push(num); nums.push(num);
} }
nums.into()
nums };
}
.into();
println!("Running Roc quicksort on {} numbers...", nums.len()); println!("Running Roc quicksort on {} numbers...", nums.len());
let start_time = SystemTime::now(); let start_time = SystemTime::now();