start work on tail recursion

This commit is contained in:
Folkert 2020-08-12 12:57:20 +02:00
parent 8444c1fe6d
commit bdd8751107
5 changed files with 119 additions and 55 deletions

View file

@ -117,7 +117,10 @@ pub fn gen(
Declare(def) | Builtin(def) => match def.loc_pattern.value { Declare(def) | Builtin(def) => match def.loc_pattern.value {
Identifier(symbol) => { Identifier(symbol) => {
match def.loc_expr.value { match def.loc_expr.value {
Closure(annotation, _, _, loc_args, boxed_body) => { Closure(annotation, _, recursivity, loc_args, boxed_body) => {
let is_tail_recursive =
matches!(recursivity, roc_can::expr::Recursive::TailRecursive);
let (loc_body, ret_var) = *boxed_body; let (loc_body, ret_var) = *boxed_body;
// If this is an exposed symbol, we need to // If this is an exposed symbol, we need to
@ -155,6 +158,7 @@ pub fn gen(
annotation, annotation,
loc_args, loc_args,
loc_body, loc_body,
is_tail_recursive,
ret_var, ret_var,
); );
} }
@ -166,6 +170,7 @@ pub fn gen(
pattern_symbols: bumpalo::collections::Vec::new_in( pattern_symbols: bumpalo::collections::Vec::new_in(
mono_env.arena, mono_env.arena,
), ),
is_tail_recursive: false,
body, body,
}; };

View file

@ -173,7 +173,7 @@ pub fn construct_optimization_passes<'a>(
} }
OptLevel::Optimize => { OptLevel::Optimize => {
// this threshold seems to do what we want // this threshold seems to do what we want
pmb.set_inliner_with_threshold(0); pmb.set_inliner_with_threshold(2);
// TODO figure out which of these actually help // TODO figure out which of these actually help
@ -185,21 +185,18 @@ pub fn construct_optimization_passes<'a>(
fpm.add_jump_threading_pass(); fpm.add_jump_threading_pass();
mpm.add_jump_threading_pass(); mpm.add_jump_threading_pass();
//fpm.add_ind_var_simplify_pass();
fpm.add_memcpy_optimize_pass(); // this one is very important fpm.add_memcpy_optimize_pass(); // this one is very important
// In my testing, these don't do much for quicksort // In my testing, these don't do much for quicksort
//fpm.add_ind_var_simplify_pass();
// fpm.add_basic_alias_analysis_pass(); // fpm.add_basic_alias_analysis_pass();
// fpm.add_jump_threading_pass(); // fpm.add_jump_threading_pass();
// fpm.add_instruction_combining_pass();
// fpm.add_licm_pass(); // fpm.add_licm_pass();
// fpm.add_loop_unroll_pass();
// fpm.add_scalar_repl_aggregates_pass_ssa(); // fpm.add_scalar_repl_aggregates_pass_ssa();
// fpm.add_cfg_simplification_pass(); // fpm.add_cfg_simplification_pass();
// fpm.add_jump_threading_pass(); // fpm.add_jump_threading_pass();
// module passes // module passes
// fpm.add_promote_memory_to_register_pass();
} }
} }

View file

@ -265,7 +265,7 @@ mod gen_list {
assert_evals_to!( assert_evals_to!(
&format!("List.concat {} {}", slice_str1, slice_str2), &format!("List.concat {} {}", slice_str1, slice_str2),
expected_slice, expected_slice,
&'static [i64], &'static [i64]
); );
} }
@ -814,7 +814,7 @@ mod gen_list {
"# "#
), ),
&[4, 7, 19, 21], &[4, 7, 19, 21],
&'static [i64], &'static [i64]
); );
}) })
} }
@ -1012,7 +1012,7 @@ mod gen_list {
"# "#
), ),
&[1, 2, 3], &[1, 2, 3],
&'static [i64], &'static [i64]
); );
} }
@ -1031,7 +1031,7 @@ mod gen_list {
"# "#
), ),
&[0, 2, 3], &[0, 2, 3],
&'static [i64], &'static [i64]
); );
} }
@ -1048,7 +1048,7 @@ mod gen_list {
"# "#
), ),
&[1, 2, 3], &[1, 2, 3],
&'static [i64], &'static [i64]
); );
} }
} }

View file

@ -23,6 +23,7 @@ pub struct PartialProc<'a> {
pub annotation: Variable, pub annotation: Variable,
pub pattern_symbols: Vec<'a, Symbol>, pub pattern_symbols: Vec<'a, Symbol>,
pub body: roc_can::expr::Expr, pub body: roc_can::expr::Expr,
pub is_tail_recursive: bool,
} }
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
@ -39,6 +40,7 @@ pub struct Proc<'a> {
pub body: Stmt<'a>, pub body: Stmt<'a>,
pub closes_over: Layout<'a>, pub closes_over: Layout<'a>,
pub ret_layout: Layout<'a>, pub ret_layout: Layout<'a>,
pub is_tail_recursive: bool,
} }
impl<'a> Proc<'a> { impl<'a> Proc<'a> {
@ -128,6 +130,7 @@ impl<'a> Procs<'a> {
annotation: Variable, annotation: Variable,
loc_args: std::vec::Vec<(Variable, Located<roc_can::pattern::Pattern>)>, loc_args: std::vec::Vec<(Variable, Located<roc_can::pattern::Pattern>)>,
loc_body: Located<roc_can::expr::Expr>, loc_body: Located<roc_can::expr::Expr>,
is_tail_recursive: bool,
ret_var: Variable, ret_var: Variable,
) { ) {
match patterns_to_when(env, layout_cache, loc_args, ret_var, loc_body) { match patterns_to_when(env, layout_cache, loc_args, ret_var, loc_body) {
@ -142,6 +145,7 @@ impl<'a> Procs<'a> {
annotation, annotation,
pattern_symbols, pattern_symbols,
body: body.value, body: body.value,
is_tail_recursive,
}, },
); );
} }
@ -174,6 +178,9 @@ impl<'a> Procs<'a> {
ret_var: Variable, ret_var: Variable,
layout_cache: &mut LayoutCache<'a>, layout_cache: &mut LayoutCache<'a>,
) -> Result<Layout<'a>, RuntimeError> { ) -> Result<Layout<'a>, RuntimeError> {
// anonymous functions cannot reference themselves, therefore cannot be tail-recursive
let is_tail_recursive = false;
match patterns_to_when(env, layout_cache, loc_args, ret_var, loc_body) { match patterns_to_when(env, layout_cache, loc_args, ret_var, loc_body) {
Ok((pattern_vars, pattern_symbols, body)) => { Ok((pattern_vars, pattern_symbols, body)) => {
// an anonymous closure. These will always be specialized already // an anonymous closure. These will always be specialized already
@ -212,6 +219,7 @@ impl<'a> Procs<'a> {
annotation, annotation,
pattern_symbols, pattern_symbols,
body: body.value, body: body.value,
is_tail_recursive,
}, },
); );
} }
@ -221,6 +229,7 @@ impl<'a> Procs<'a> {
annotation, annotation,
pattern_symbols, pattern_symbols,
body: body.value, body: body.value,
is_tail_recursive,
}; };
// Mark this proc as in-progress, so if we're dealing with // Mark this proc as in-progress, so if we're dealing with
@ -991,6 +1000,7 @@ fn specialize<'a>(
annotation, annotation,
pattern_symbols, pattern_symbols,
body, body,
is_tail_recursive,
} = partial_proc; } = partial_proc;
// unify the called function with the specialized signature, then specialize the function body // unify the called function with the specialized signature, then specialize the function body
@ -1034,6 +1044,7 @@ fn specialize<'a>(
body: specialized_body, body: specialized_body,
closes_over: closes_over_layout, closes_over: closes_over_layout,
ret_layout, ret_layout,
is_tail_recursive,
}; };
Ok(proc) Ok(proc)
@ -1089,16 +1100,15 @@ pub fn with_hole<'a>(
}, },
LetNonRec(def, cont, _, _) => { LetNonRec(def, cont, _, _) => {
if let roc_can::pattern::Pattern::Identifier(symbol) = &def.loc_pattern.value { if let roc_can::pattern::Pattern::Identifier(symbol) = &def.loc_pattern.value {
if let Closure(_, _, _, _, _) = &def.loc_expr.value { if let Closure(ann, _, recursivity, loc_args, boxed_body) = def.loc_expr.value {
// Now that we know for sure it's a closure, get an owned
// version of these variant args so we can use them properly.
match def.loc_expr.value {
Closure(ann, _, _, loc_args, boxed_body) => {
// Extract Procs, but discard the resulting Expr::Load. // Extract Procs, but discard the resulting Expr::Load.
// That Load looks up the pointer, which we won't use here! // That Load looks up the pointer, which we won't use here!
let (loc_body, ret_var) = *boxed_body; let (loc_body, ret_var) = *boxed_body;
let is_tail_recursive =
matches!(recursivity, roc_can::expr::Recursive::TailRecursive);
procs.insert_named( procs.insert_named(
env, env,
layout_cache, layout_cache,
@ -1106,14 +1116,12 @@ pub fn with_hole<'a>(
ann, ann,
loc_args, loc_args,
loc_body, loc_body,
is_tail_recursive,
ret_var, ret_var,
); );
return with_hole(env, cont.value, procs, layout_cache, assigned, hole); return with_hole(env, cont.value, procs, layout_cache, assigned, hole);
} }
_ => unreachable!(),
}
}
} }
if let roc_can::pattern::Pattern::Identifier(symbol) = def.loc_pattern.value { if let roc_can::pattern::Pattern::Identifier(symbol) = def.loc_pattern.value {
@ -1176,15 +1184,15 @@ pub fn with_hole<'a>(
// because Roc is strict, only functions can be recursive! // because Roc is strict, only functions can be recursive!
for def in defs.into_iter() { for def in defs.into_iter() {
if let roc_can::pattern::Pattern::Identifier(symbol) = &def.loc_pattern.value { if let roc_can::pattern::Pattern::Identifier(symbol) = &def.loc_pattern.value {
// Now that we know for sure it's a closure, get an owned if let Closure(ann, _, recursivity, loc_args, boxed_body) = def.loc_expr.value {
// version of these variant args so we can use them properly.
match def.loc_expr.value {
Closure(ann, _, _, loc_args, boxed_body) => {
// Extract Procs, but discard the resulting Expr::Load. // Extract Procs, but discard the resulting Expr::Load.
// That Load looks up the pointer, which we won't use here! // That Load looks up the pointer, which we won't use here!
let (loc_body, ret_var) = *boxed_body; let (loc_body, ret_var) = *boxed_body;
let is_tail_recursive =
matches!(recursivity, roc_can::expr::Recursive::TailRecursive);
procs.insert_named( procs.insert_named(
env, env,
layout_cache, layout_cache,
@ -1192,13 +1200,12 @@ pub fn with_hole<'a>(
ann, ann,
loc_args, loc_args,
loc_body, loc_body,
is_tail_recursive,
ret_var, ret_var,
); );
continue; continue;
} }
_ => unreachable!("recursive value is not a function"),
}
} }
unreachable!("recursive value does not have Identifier pattern") unreachable!("recursive value does not have Identifier pattern")
} }
@ -2025,12 +2032,15 @@ pub fn from_can<'a>(
// Now that we know for sure it's a closure, get an owned // Now that we know for sure it's a closure, get an owned
// version of these variant args so we can use them properly. // version of these variant args so we can use them properly.
match def.loc_expr.value { match def.loc_expr.value {
Closure(ann, _, _, loc_args, boxed_body) => { Closure(ann, _, recursivity, loc_args, boxed_body) => {
// Extract Procs, but discard the resulting Expr::Load. // Extract Procs, but discard the resulting Expr::Load.
// That Load looks up the pointer, which we won't use here! // That Load looks up the pointer, which we won't use here!
let (loc_body, ret_var) = *boxed_body; let (loc_body, ret_var) = *boxed_body;
let is_tail_recursive =
matches!(recursivity, roc_can::expr::Recursive::TailRecursive);
procs.insert_named( procs.insert_named(
env, env,
layout_cache, layout_cache,
@ -2038,6 +2048,7 @@ pub fn from_can<'a>(
ann, ann,
loc_args, loc_args,
loc_body, loc_body,
is_tail_recursive,
ret_var, ret_var,
); );
@ -2057,12 +2068,15 @@ pub fn from_can<'a>(
// Now that we know for sure it's a closure, get an owned // Now that we know for sure it's a closure, get an owned
// version of these variant args so we can use them properly. // version of these variant args so we can use them properly.
match def.loc_expr.value { match def.loc_expr.value {
Closure(ann, _, _, loc_args, boxed_body) => { Closure(ann, _, recursivity, loc_args, boxed_body) => {
// Extract Procs, but discard the resulting Expr::Load. // Extract Procs, but discard the resulting Expr::Load.
// That Load looks up the pointer, which we won't use here! // That Load looks up the pointer, which we won't use here!
let (loc_body, ret_var) = *boxed_body; let (loc_body, ret_var) = *boxed_body;
let is_tail_recursive =
matches!(recursivity, roc_can::expr::Recursive::TailRecursive);
procs.insert_named( procs.insert_named(
env, env,
layout_cache, layout_cache,
@ -2070,6 +2084,7 @@ pub fn from_can<'a>(
ann, ann,
loc_args, loc_args,
loc_body, loc_body,
is_tail_recursive,
ret_var, ret_var,
); );
@ -2765,11 +2780,11 @@ fn store_pattern<'a>(
} }
Shadowed(_region, _ident) => { Shadowed(_region, _ident) => {
return Err(&"TODO"); return Err(&"shadowed");
} }
UnsupportedPattern(_region) => { UnsupportedPattern(_region) => {
return Err(&"TODO"); return Err(&"unsupported pattern");
} }
} }

View file

@ -1064,4 +1064,51 @@ mod test_mono {
) )
}) })
} }
#[test]
fn factorial() {
compiles_to_ir(
r#"
factorial = \n, accum ->
when n is
0 ->
accum
_ ->
factorial (n - 1) (n * accum)
factorial 10 1
"#,
indoc!(
r#"
procedure Test.0 (Test.2, Test.3):
let Test.15 = true;
let Test.16 = 0i64;
let Test.17 = lowlevel Eq Test.16 Test.2;
let Test.14 = lowlevel And Test.17 Test.15;
if Test.14 then
ret Test.3;
else
let Test.12 = 1i64;
let Test.9 = CallByName Num.15 Test.2 Test.12;
let Test.10 = CallByName Num.16 Test.2 Test.3;
let Test.8 = CallByName Test.0 Test.9 Test.10;
ret Test.8;
procedure Num.15 (#Attr.2, #Attr.3):
let Test.13 = lowlevel NumSub #Attr.2 #Attr.3;
ret Test.13;
procedure Num.16 (#Attr.2, #Attr.3):
let Test.11 = lowlevel NumMul #Attr.2 #Attr.3;
ret Test.11;
let Test.5 = 10i64;
let Test.6 = 1i64;
let Test.4 = CallByName Test.0 Test.5 Test.6;
ret Test.4;
"#
),
)
}
} }