Merge branch 'trunk' into str-concat

This commit is contained in:
Richard Feldman 2020-08-19 23:05:13 -04:00 committed by GitHub
commit 50251c678b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 1565 additions and 386 deletions

View file

@ -258,8 +258,13 @@ pub fn gen(src: &[u8], target: Triple, opt_level: OptLevel) -> Result<(String, S
}; };
let main_body = roc_mono::ir::Stmt::new(&mut mono_env, loc_expr.value, &mut procs); let main_body = roc_mono::ir::Stmt::new(&mut mono_env, loc_expr.value, &mut procs);
let main_body =
roc_mono::inc_dec::visit_declaration(mono_env.arena, mono_env.arena.alloc(main_body)); let param_map = roc_mono::borrow::ParamMap::default();
let main_body = roc_mono::inc_dec::visit_declaration(
mono_env.arena,
mono_env.arena.alloc(param_map),
mono_env.arena.alloc(main_body),
);
let mut headers = { let mut headers = {
let num_headers = match &procs.pending_specializations { let num_headers = match &procs.pending_specializations {
Some(map) => map.len(), Some(map) => map.len(),

View file

@ -158,7 +158,7 @@ pub fn gen(
pattern_symbols: bumpalo::collections::Vec::new_in( pattern_symbols: bumpalo::collections::Vec::new_in(
mono_env.arena, mono_env.arena,
), ),
is_tail_recursive: false, is_self_recursive: false,
body, body,
}; };

View file

@ -37,6 +37,9 @@ const PRINT_FN_VERIFICATION_OUTPUT: bool = true;
#[cfg(not(debug_assertions))] #[cfg(not(debug_assertions))]
const PRINT_FN_VERIFICATION_OUTPUT: bool = false; const PRINT_FN_VERIFICATION_OUTPUT: bool = false;
pub const REFCOUNT_0: usize = std::usize::MAX;
pub const REFCOUNT_1: usize = REFCOUNT_0 - 1;
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub enum OptLevel { pub enum OptLevel {
Normal, Normal,
@ -904,7 +907,7 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>(
match layout { match layout {
Layout::Builtin(Builtin::List(MemoryMode::Refcounted, _)) => { Layout::Builtin(Builtin::List(MemoryMode::Refcounted, _)) => {
increment_refcount_list(env, value.into_struct_value()); increment_refcount_list(env, parent, value.into_struct_value());
build_exp_stmt(env, layout_ids, scope, parent, cont) build_exp_stmt(env, layout_ids, scope, parent, cont)
} }
_ => build_exp_stmt(env, layout_ids, scope, parent, cont), _ => build_exp_stmt(env, layout_ids, scope, parent, cont),
@ -929,11 +932,7 @@ fn refcount_is_one_comparison<'ctx>(
context: &'ctx Context, context: &'ctx Context,
refcount: IntValue<'ctx>, refcount: IntValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let refcount_one: IntValue<'ctx> = context.i64_type().const_int((std::usize::MAX) as _, false); let refcount_one: IntValue<'ctx> = context.i64_type().const_int(REFCOUNT_1 as _, false);
// Note: Check for refcount < refcount_1 as the "true" condition,
// to avoid misprediction. (In practice this should usually pass,
// and CPUs generally default to predicting that a forward jump
// shouldn't be taken; that is, they predict "else" won't be taken.)
builder.build_int_compare( builder.build_int_compare(
IntPredicate::EQ, IntPredicate::EQ,
refcount, refcount,
@ -998,6 +997,7 @@ fn decrement_refcount_layout<'a, 'ctx, 'env>(
} }
} }
} }
RecursiveUnion(_) => todo!("TODO implement decrement layout of recursive tag union"),
Union(tags) => { Union(tags) => {
debug_assert!(!tags.is_empty()); debug_assert!(!tags.is_empty());
let wrapper_struct = value.into_struct_value(); let wrapper_struct = value.into_struct_value();
@ -1086,11 +1086,29 @@ fn decrement_refcount_builtin<'a, 'ctx, 'env>(
fn increment_refcount_list<'a, 'ctx, 'env>( fn increment_refcount_list<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>, env: &Env<'a, 'ctx, 'env>,
parent: FunctionValue<'ctx>,
original_wrapper: StructValue<'ctx>, original_wrapper: StructValue<'ctx>,
) { ) {
let builder = env.builder; let builder = env.builder;
let ctx = env.context; let ctx = env.context;
let len = list_len(builder, original_wrapper);
let is_non_empty = builder.build_int_compare(
IntPredicate::UGT,
len,
ctx.i64_type().const_zero(),
"len > 0",
);
// build blocks
let increment_block = ctx.append_basic_block(parent, "increment_block");
let cont_block = ctx.append_basic_block(parent, "after_increment_block");
builder.build_conditional_branch(is_non_empty, increment_block, cont_block);
builder.position_at_end(increment_block);
let refcount_ptr = list_get_refcount_ptr(env, original_wrapper); let refcount_ptr = list_get_refcount_ptr(env, original_wrapper);
let refcount = env let refcount = env
@ -1107,6 +1125,9 @@ fn increment_refcount_list<'a, 'ctx, 'env>(
// Mutate the new array in-place to change the element. // Mutate the new array in-place to change the element.
builder.build_store(refcount_ptr, decremented); builder.build_store(refcount_ptr, decremented);
builder.build_unconditional_branch(cont_block);
builder.position_at_end(cont_block);
} }
fn decrement_refcount_list<'a, 'ctx, 'env>( fn decrement_refcount_list<'a, 'ctx, 'env>(
@ -1117,6 +1138,30 @@ fn decrement_refcount_list<'a, 'ctx, 'env>(
let builder = env.builder; let builder = env.builder;
let ctx = env.context; let ctx = env.context;
// the block we'll always jump to when we're done
let cont_block = ctx.append_basic_block(parent, "after_decrement_block");
let decrement_block = ctx.append_basic_block(parent, "decrement_block");
// currently, an empty list has a null-pointer in its length is 0
// so we must first check the length
let len = list_len(builder, original_wrapper);
let is_non_empty = builder.build_int_compare(
IntPredicate::UGT,
len,
ctx.i64_type().const_zero(),
"len > 0",
);
// if the length is 0, we're done and jump to the continuation block
// otherwise, actually read and check the refcount
builder.build_conditional_branch(is_non_empty, decrement_block, cont_block);
builder.position_at_end(decrement_block);
// build blocks
let then_block = ctx.append_basic_block(parent, "then");
let else_block = ctx.append_basic_block(parent, "else");
let refcount_ptr = list_get_refcount_ptr(env, original_wrapper); let refcount_ptr = list_get_refcount_ptr(env, original_wrapper);
let refcount = env let refcount = env
@ -1126,16 +1171,24 @@ fn decrement_refcount_list<'a, 'ctx, 'env>(
let comparison = refcount_is_one_comparison(builder, env.context, refcount); let comparison = refcount_is_one_comparison(builder, env.context, refcount);
// build blocks // TODO what would be most optimial for the branch predictor
let then_block = ctx.append_basic_block(parent, "then"); //
let else_block = ctx.append_basic_block(parent, "else"); // are most refcounts 1 most of the time? or not?
let cont_block = ctx.append_basic_block(parent, "dec_ref_branchcont");
builder.build_conditional_branch(comparison, then_block, else_block); builder.build_conditional_branch(comparison, then_block, else_block);
// build then block // build then block
{ {
builder.position_at_end(then_block); builder.position_at_end(then_block);
if !env.leak {
let free = builder.build_free(refcount_ptr);
builder.insert_instruction(&free, None);
}
builder.build_unconditional_branch(cont_block);
}
// build else block
{
builder.position_at_end(else_block);
// our refcount 0 is actually usize::MAX, so decrementing the refcount means incrementing this value. // our refcount 0 is actually usize::MAX, so decrementing the refcount means incrementing this value.
let decremented = env.builder.build_int_add( let decremented = env.builder.build_int_add(
ctx.i64_type().const_int(1 as u64, false), ctx.i64_type().const_int(1 as u64, false),
@ -1149,16 +1202,6 @@ fn decrement_refcount_list<'a, 'ctx, 'env>(
builder.build_unconditional_branch(cont_block); builder.build_unconditional_branch(cont_block);
} }
// build else block
{
builder.position_at_end(else_block);
if !env.leak {
let free = builder.build_free(refcount_ptr);
builder.insert_instruction(&free, None);
}
builder.build_unconditional_branch(cont_block);
}
// emit merge block // emit merge block
builder.position_at_end(cont_block); builder.position_at_end(cont_block);
} }
@ -1804,14 +1847,9 @@ fn run_low_level<'a, 'ctx, 'env>(
list_get_unsafe(env, list_layout, elem_index, wrapper_struct) list_get_unsafe(env, list_layout, elem_index, wrapper_struct)
} }
ListSet => { ListSetInPlace => {
let (list_symbol, list_layout) = load_symbol_and_layout(env, scope, &args[0]); let (list_symbol, list_layout) = load_symbol_and_layout(env, scope, &args[0]);
let in_place = match &list_layout {
Layout::Builtin(Builtin::List(MemoryMode::Unique, _)) => InPlace::InPlace,
_ => InPlace::Clone,
};
list_set( list_set(
parent, parent,
&[ &[
@ -1820,19 +1858,57 @@ fn run_low_level<'a, 'ctx, 'env>(
(load_symbol_and_layout(env, scope, &args[2])), (load_symbol_and_layout(env, scope, &args[2])),
], ],
env, env,
in_place, InPlace::InPlace,
) )
} }
ListSetInPlace => list_set( ListSet => {
parent, let (list_symbol, list_layout) = load_symbol_and_layout(env, scope, &args[0]);
&[
(load_symbol_and_layout(env, scope, &args[0])), let arguments = &[
(list_symbol, list_layout),
(load_symbol_and_layout(env, scope, &args[1])), (load_symbol_and_layout(env, scope, &args[1])),
(load_symbol_and_layout(env, scope, &args[2])), (load_symbol_and_layout(env, scope, &args[2])),
], ];
env,
InPlace::InPlace, match list_layout {
), Layout::Builtin(Builtin::List(MemoryMode::Unique, _)) => {
// the layout tells us this List.set can be done in-place
list_set(parent, arguments, env, InPlace::InPlace)
}
Layout::Builtin(Builtin::List(MemoryMode::Refcounted, _)) => {
// no static guarantees, but all is not lost: we can check the refcount
// if it is one, we hold the final reference, and can mutate it in-place!
let builder = env.builder;
let ctx = env.context;
let ret_type =
basic_type_from_layout(env.arena, ctx, list_layout, env.ptr_bytes);
let refcount_ptr = list_get_refcount_ptr(env, list_symbol.into_struct_value());
let refcount = env
.builder
.build_load(refcount_ptr, "get_refcount")
.into_int_value();
let comparison = refcount_is_one_comparison(builder, env.context, refcount);
// build then block
// refcount is 1, so work in-place
let build_pass = || list_set(parent, arguments, env, InPlace::InPlace);
// build else block
// refcount != 1, so clone first
let build_fail = || list_set(parent, arguments, env, InPlace::Clone);
crate::llvm::build_list::build_basic_phi2(
env, parent, comparison, build_pass, build_fail, ret_type,
)
}
Layout::Builtin(Builtin::EmptyList) => list_symbol,
other => unreachable!("List.set: weird layout {:?}", other),
}
}
} }
} }

View file

@ -207,10 +207,7 @@ pub fn list_prepend<'a, 'ctx, 'env>(
let ptr_bytes = env.ptr_bytes; let ptr_bytes = env.ptr_bytes;
// Allocate space for the new array that we'll copy into. // Allocate space for the new array that we'll copy into.
let elem_type = basic_type_from_layout(env.arena, ctx, elem_layout, env.ptr_bytes); let clone_ptr = allocate_list(env, elem_layout, new_list_len);
let clone_ptr = builder
.build_array_malloc(elem_type, new_list_len, "list_ptr")
.unwrap();
let int_type = ptr_int(ctx, ptr_bytes); let int_type = ptr_int(ctx, ptr_bytes);
let ptr_as_int = builder.build_ptr_to_int(clone_ptr, int_type, "list_cast_ptr"); let ptr_as_int = builder.build_ptr_to_int(clone_ptr, int_type, "list_cast_ptr");
@ -355,9 +352,7 @@ pub fn list_join<'a, 'ctx, 'env>(
.build_load(list_len_sum_alloca, list_len_sum_name) .build_load(list_len_sum_alloca, list_len_sum_name)
.into_int_value(); .into_int_value();
let final_list_ptr = builder let final_list_ptr = allocate_list(env, elem_layout, final_list_sum);
.build_array_malloc(elem_type, final_list_sum, "final_list_sum")
.unwrap();
let dest_elem_ptr_alloca = builder.build_alloca(elem_ptr_type, "dest_elem"); let dest_elem_ptr_alloca = builder.build_alloca(elem_ptr_type, "dest_elem");
@ -1375,9 +1370,12 @@ pub fn allocate_list<'a, 'ctx, 'env>(
"make ptr", "make ptr",
); );
// put our "refcount 0" in the first slot // the refcount of a new list is initially 1
let ref_count_zero = ctx.i64_type().const_int(std::usize::MAX as u64, false); // we assume that the list is indeed used (dead variables are eliminated)
builder.build_store(refcount_ptr, ref_count_zero); let ref_count_one = ctx
.i64_type()
.const_int(crate::llvm::build::REFCOUNT_1 as _, false);
builder.build_store(refcount_ptr, ref_count_one);
list_element_ptr list_element_ptr
} }

View file

@ -107,6 +107,7 @@ pub fn basic_type_from_layout<'ctx>(
.struct_type(field_types.into_bump_slice(), false) .struct_type(field_types.into_bump_slice(), false)
.as_basic_type_enum() .as_basic_type_enum()
} }
RecursiveUnion(_) => todo!("TODO implement layout of recursive tag union"),
Union(_) => { Union(_) => {
// TODO make this dynamic // TODO make this dynamic
let ptr_size = std::mem::size_of::<i64>(); let ptr_size = std::mem::size_of::<i64>();

View file

@ -210,7 +210,28 @@ mod gen_list {
} }
#[test] #[test]
fn list_concat() { fn foobarbaz() {
assert_evals_to!(
indoc!(
r#"
firstList : List Int
firstList =
[]
secondList : List Int
secondList =
[]
List.concat firstList secondList
"#
),
&[],
&'static [i64]
);
}
#[test]
fn list_concat_vanilla() {
assert_evals_to!("List.concat [] []", &[], &'static [i64]); assert_evals_to!("List.concat [] []", &[], &'static [i64]);
assert_evals_to!( assert_evals_to!(
@ -516,7 +537,7 @@ mod gen_list {
assert_evals_to!( assert_evals_to!(
indoc!( indoc!(
r#" r#"
shared = [ 2.1, 4.3 ] main = \shared ->
# This should not mutate the original # This should not mutate the original
x = x =
@ -530,6 +551,8 @@ mod gen_list {
Err _ -> 0 Err _ -> 0
{ x, y } { x, y }
main [ 2.1, 4.3 ]
"# "#
), ),
(7.7, 4.3), (7.7, 4.3),
@ -542,6 +565,7 @@ mod gen_list {
assert_evals_to!( assert_evals_to!(
indoc!( indoc!(
r#" r#"
main = \{} ->
shared = [ 2, 4 ] shared = [ 2, 4 ]
# This List.set is out of bounds, and should have no effect # This List.set is out of bounds, and should have no effect
@ -556,6 +580,8 @@ mod gen_list {
Err _ -> 0 Err _ -> 0
{ x, y } { x, y }
main {}
"# "#
), ),
(4, 4), (4, 4),

View file

@ -482,9 +482,12 @@ mod gen_num {
assert_evals_to!( assert_evals_to!(
indoc!( indoc!(
r#" r#"
when 10 is main = \{} ->
x if x == 5 -> 0 when 10 is
_ -> 42 x if x == 5 -> 0
_ -> 42
main {}
"# "#
), ),
42, 42,
@ -497,9 +500,12 @@ mod gen_num {
assert_evals_to!( assert_evals_to!(
indoc!( indoc!(
r#" r#"
when 10 is main = \{} ->
x if x == 10 -> 42 when 10 is
_ -> 0 x if x == 10 -> 42
_ -> 0
main {}
"# "#
), ),
42, 42,

View file

@ -283,7 +283,10 @@ mod gen_primitives {
assert_evals_to!( assert_evals_to!(
indoc!( indoc!(
r#" r#"
main = \{} ->
(\a -> a) 5 (\a -> a) 5
main {}
"# "#
), ),
5, 5,
@ -296,11 +299,14 @@ mod gen_primitives {
assert_evals_to!( assert_evals_to!(
indoc!( indoc!(
r#" r#"
main = \{} ->
alwaysFloatIdentity : Int -> (Float -> Float) alwaysFloatIdentity : Int -> (Float -> Float)
alwaysFloatIdentity = \num -> alwaysFloatIdentity = \num ->
(\a -> a) (\a -> a)
(alwaysFloatIdentity 2) 3.14 (alwaysFloatIdentity 2) 3.14
main {}
"# "#
), ),
3.14, 3.14,

View file

@ -455,9 +455,12 @@ mod gen_tags {
assert_evals_to!( assert_evals_to!(
indoc!( indoc!(
r#" r#"
when 2 is main = \{} ->
2 if False -> 0 when 2 is
_ -> 42 2 if False -> 0
_ -> 42
main {}
"# "#
), ),
42, 42,
@ -470,9 +473,12 @@ mod gen_tags {
assert_evals_to!( assert_evals_to!(
indoc!( indoc!(
r#" r#"
when 2 is main = \{} ->
2 if True -> 42 when 2 is
_ -> 0 2 if True -> 42
_ -> 0
main {}
"# "#
), ),
42, 42,
@ -485,9 +491,12 @@ mod gen_tags {
assert_evals_to!( assert_evals_to!(
indoc!( indoc!(
r#" r#"
when 2 is main = \{} ->
_ if False -> 0 when 2 is
_ -> 42 _ if False -> 0
_ -> 42
main {}
"# "#
), ),
42, 42,
@ -665,16 +674,19 @@ mod gen_tags {
assert_evals_to!( assert_evals_to!(
indoc!( indoc!(
r#" r#"
x : [ Red, White, Blue ] main = \{} ->
x = Blue x : [ Red, White, Blue ]
x = Blue
y = y =
when x is when x is
Red -> 1 Red -> 1
White -> 2 White -> 2
Blue -> 3.1 Blue -> 3.1
y y
main {}
"# "#
), ),
3.1, 3.1,
@ -687,13 +699,16 @@ mod gen_tags {
assert_evals_to!( assert_evals_to!(
indoc!( indoc!(
r#" r#"
y = main = \{} ->
when 1 + 2 is y =
3 -> 3 when 1 + 2 is
1 -> 1 3 -> 3
_ -> 0 1 -> 1
_ -> 0
y y
main {}
"# "#
), ),
3, 3,

View file

@ -106,8 +106,6 @@ pub fn helper_without_uniqueness<'a>(
}; };
let main_body = roc_mono::ir::Stmt::new(&mut mono_env, loc_expr.value, &mut procs); let main_body = roc_mono::ir::Stmt::new(&mut mono_env, loc_expr.value, &mut procs);
let main_body =
roc_mono::inc_dec::visit_declaration(mono_env.arena, mono_env.arena.alloc(main_body));
let mut headers = { let mut headers = {
let num_headers = match &procs.pending_specializations { let num_headers = match &procs.pending_specializations {
@ -125,6 +123,13 @@ pub fn helper_without_uniqueness<'a>(
roc_collections::all::MutMap::default() roc_collections::all::MutMap::default()
); );
let (mut procs, param_map) = procs.get_specialized_procs_help(mono_env.arena);
let main_body = roc_mono::inc_dec::visit_declaration(
mono_env.arena,
param_map,
mono_env.arena.alloc(main_body),
);
// Put this module's ident_ids back in the interns, so we can use them in env. // Put this module's ident_ids back in the interns, so we can use them in env.
// This must happen *after* building the headers, because otherwise there's // This must happen *after* building the headers, because otherwise there's
// a conflicting mutable borrow on ident_ids. // a conflicting mutable borrow on ident_ids.
@ -133,8 +138,7 @@ pub fn helper_without_uniqueness<'a>(
// Add all the Proc headers to the module. // Add all the Proc headers to the module.
// We have to do this in a separate pass first, // We have to do this in a separate pass first,
// because their bodies may reference each other. // because their bodies may reference each other.
for ((symbol, layout), proc) in procs.drain() {
for ((symbol, layout), proc) in procs.get_specialized_procs(env.arena).drain() {
let fn_val = build_proc_header(&env, &mut layout_ids, symbol, &layout, &proc); let fn_val = build_proc_header(&env, &mut layout_ids, symbol, &layout, &proc);
headers.push((proc, fn_val)); headers.push((proc, fn_val));
@ -296,8 +300,6 @@ pub fn helper_with_uniqueness<'a>(
}; };
let main_body = roc_mono::ir::Stmt::new(&mut mono_env, loc_expr.value, &mut procs); let main_body = roc_mono::ir::Stmt::new(&mut mono_env, loc_expr.value, &mut procs);
let main_body =
roc_mono::inc_dec::visit_declaration(mono_env.arena, mono_env.arena.alloc(main_body));
let mut headers = { let mut headers = {
let num_headers = match &procs.pending_specializations { let num_headers = match &procs.pending_specializations {
Some(map) => map.len(), Some(map) => map.len(),
@ -314,6 +316,13 @@ pub fn helper_with_uniqueness<'a>(
roc_collections::all::MutMap::default() roc_collections::all::MutMap::default()
); );
let (mut procs, param_map) = procs.get_specialized_procs_help(mono_env.arena);
let main_body = roc_mono::inc_dec::visit_declaration(
mono_env.arena,
param_map,
mono_env.arena.alloc(main_body),
);
// Put this module's ident_ids back in the interns, so we can use them in env. // Put this module's ident_ids back in the interns, so we can use them in env.
// This must happen *after* building the headers, because otherwise there's // This must happen *after* building the headers, because otherwise there's
// a conflicting mutable borrow on ident_ids. // a conflicting mutable borrow on ident_ids.
@ -322,7 +331,7 @@ pub fn helper_with_uniqueness<'a>(
// Add all the Proc headers to the module. // Add all the Proc headers to the module.
// We have to do this in a separate pass first, // We have to do this in a separate pass first,
// because their bodies may reference each other. // because their bodies may reference each other.
for ((symbol, layout), proc) in procs.get_specialized_procs(env.arena).drain() { for ((symbol, layout), proc) in procs.drain() {
let fn_val = build_proc_header(&env, &mut layout_ids, symbol, &layout, &proc); let fn_val = build_proc_header(&env, &mut layout_ids, symbol, &layout, &proc);
headers.push((proc, fn_val)); headers.push((proc, fn_val));

497
compiler/mono/src/borrow.rs Normal file
View file

@ -0,0 +1,497 @@
use crate::ir::{Expr, JoinPointId, Param, Proc, Stmt};
use crate::layout::Layout;
use bumpalo::collections::Vec;
use bumpalo::Bump;
use roc_collections::all::{MutMap, MutSet};
use roc_module::low_level::LowLevel;
use roc_module::symbol::Symbol;
pub fn infer_borrow<'a>(
arena: &'a Bump,
procs: &MutMap<(Symbol, Layout<'a>), Proc<'a>>,
) -> ParamMap<'a> {
let mut param_map = ParamMap {
items: MutMap::default(),
};
for proc in procs.values() {
param_map.visit_proc(arena, proc);
}
let mut env = BorrowInfState {
current_proc: Symbol::ATTR_ATTR,
param_set: MutSet::default(),
owned: MutMap::default(),
modified: false,
param_map,
arena,
};
// This is a fixed-point analysis
//
// all functions initiall own all their paramters
// through a series of checks and heuristics, some arguments are set to borrowed
// when that doesn't lead to conflicts the change is kept, otherwise it may be reverted
//
// when the signatures no longer change, the analysis stops and returns the signatures
loop {
// sort the symbols (roughly) in definition order.
// TODO in the future I think we need to do this properly, and group
// mutually recursive functions (or just make all their arguments owned)
for proc in procs.values() {
env.collect_proc(proc);
}
if !env.modified {
// if there were no modifications, we're done
break;
} else {
// otherwise see if there are changes after another iteration
env.modified = false;
}
}
env.param_map
}
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
enum Key {
Declaration(Symbol),
JoinPoint(JoinPointId),
}
#[derive(Debug, Clone, Default)]
pub struct ParamMap<'a> {
items: MutMap<Key, &'a [Param<'a>]>,
}
impl<'a> ParamMap<'a> {
pub fn get_symbol(&self, symbol: Symbol) -> Option<&'a [Param<'a>]> {
let key = Key::Declaration(symbol);
self.items.get(&key).copied()
}
pub fn get_join_point(&self, id: JoinPointId) -> &'a [Param<'a>] {
let key = Key::JoinPoint(id);
match self.items.get(&key) {
Some(slice) => slice,
None => unreachable!("join point not in param map: {:?}", id),
}
}
}
impl<'a> ParamMap<'a> {
fn init_borrow_params(arena: &'a Bump, ps: &'a [Param<'a>]) -> &'a [Param<'a>] {
Vec::from_iter_in(
ps.iter().map(|p| Param {
borrow: p.layout.is_refcounted(),
layout: p.layout.clone(),
symbol: p.symbol,
}),
arena,
)
.into_bump_slice()
}
fn init_borrow_args(arena: &'a Bump, ps: &'a [(Layout<'a>, Symbol)]) -> &'a [Param<'a>] {
Vec::from_iter_in(
ps.iter().map(|(layout, symbol)| Param {
borrow: layout.is_refcounted(),
layout: layout.clone(),
symbol: *symbol,
}),
arena,
)
.into_bump_slice()
}
fn visit_proc(&mut self, arena: &'a Bump, proc: &Proc<'a>) {
self.items.insert(
Key::Declaration(proc.name),
Self::init_borrow_args(arena, proc.args),
);
self.visit_stmt(arena, proc.name, &proc.body);
}
fn visit_stmt(&mut self, arena: &'a Bump, _fnid: Symbol, stmt: &Stmt<'a>) {
use Stmt::*;
let mut stack = bumpalo::vec![ in arena; stmt ];
while let Some(stmt) = stack.pop() {
match stmt {
Join {
id: j,
parameters: xs,
remainder: v,
continuation: b,
} => {
self.items
.insert(Key::JoinPoint(*j), Self::init_borrow_params(arena, xs));
stack.push(v);
stack.push(b);
}
Let(_, _, _, cont) => {
stack.push(cont);
}
Cond { pass, fail, .. } => {
stack.push(pass);
stack.push(fail);
}
Switch {
branches,
default_branch,
..
} => {
stack.extend(branches.iter().map(|b| &b.1));
stack.push(default_branch);
}
Inc(_, _) | Dec(_, _) => unreachable!("these have not been introduced yet"),
Ret(_) | Jump(_, _) | RuntimeError(_) => {
// these are terminal, do nothing
}
}
}
}
}
// Apply the inferred borrow annotations stored in ParamMap to a block of mutually recursive procs
struct BorrowInfState<'a> {
current_proc: Symbol,
param_set: MutSet<Symbol>,
owned: MutMap<Symbol, MutSet<Symbol>>,
modified: bool,
param_map: ParamMap<'a>,
arena: &'a Bump,
}
impl<'a> BorrowInfState<'a> {
pub fn own_var(&mut self, x: Symbol) {
let current = self.owned.get_mut(&self.current_proc).unwrap();
if current.contains(&x) {
// do nothing
} else {
current.insert(x);
self.modified = true;
}
}
fn is_owned(&self, x: Symbol) -> bool {
match self.owned.get(&self.current_proc) {
None => unreachable!(
"the current procedure symbol {:?} is not in the owned map",
self.current_proc
),
Some(set) => set.contains(&x),
}
}
fn update_param_map(&mut self, k: Key) {
let arena = self.arena;
if let Some(ps) = self.param_map.items.get(&k) {
let ps = Vec::from_iter_in(
ps.iter().map(|p| {
if !p.borrow {
p.clone()
} else if self.is_owned(p.symbol) {
self.modified = true;
let mut p = p.clone();
p.borrow = false;
p
} else {
p.clone()
}
}),
arena,
);
self.param_map.items.insert(k, ps.into_bump_slice());
}
}
/// This looks at an application `f x1 x2 x3`
/// If the parameter (based on the definition of `f`) is owned,
/// then the argument must also be owned
fn own_args_using_params(&mut self, xs: &[Symbol], ps: &[Param<'a>]) {
debug_assert_eq!(xs.len(), ps.len());
for (x, p) in xs.iter().zip(ps.iter()) {
if !p.borrow {
self.own_var(*x);
}
}
}
/// This looks at an application `f x1 x2 x3`
/// If the parameter (based on the definition of `f`) is owned,
/// then the argument must also be owned
fn own_args_using_bools(&mut self, xs: &[Symbol], ps: &[bool]) {
debug_assert_eq!(xs.len(), ps.len());
for (x, borrow) in xs.iter().zip(ps.iter()) {
if !borrow {
self.own_var(*x);
}
}
}
/// For each xs[i], if xs[i] is owned, then mark ps[i] as owned.
/// We use this action to preserve tail calls. That is, if we have
/// a tail call `f xs`, if the i-th parameter is borrowed, but `xs[i]` is owned
/// we would have to insert a `dec xs[i]` after `f xs` and consequently
/// "break" the tail call.
fn own_params_using_args(&mut self, xs: &[Symbol], ps: &[Param<'a>]) {
debug_assert_eq!(xs.len(), ps.len());
for (x, p) in xs.iter().zip(ps.iter()) {
if self.is_owned(*x) {
self.own_var(p.symbol);
}
}
}
/// Mark `xs[i]` as owned if it is one of the parameters `ps`.
/// We use this action to mark function parameters that are being "packed" inside constructors.
/// This is a heuristic, and is not related with the effectiveness of the reset/reuse optimization.
/// It is useful for code such as
///
/// > def f (x y : obj) :=
/// > let z := ctor_1 x y;
/// > ret z
fn own_args_if_param(&mut self, xs: &[Symbol]) {
for x in xs.iter() {
// TODO may also be asking for the index here? see Lean
if self.param_set.contains(x) {
self.own_var(*x);
}
}
}
/// This looks at the assignement
///
/// let z = e in ...
///
/// and determines whether z and which of the symbols used in e
/// must be taken as owned paramters
fn collect_expr(&mut self, z: Symbol, e: &Expr<'a>) {
use Expr::*;
match e {
Tag { arguments: xs, .. } | Struct(xs) | Array { elems: xs, .. } => {
self.own_var(z);
// if the used symbol is an argument to the current function,
// the function must take it as an owned parameter
self.own_args_if_param(xs);
}
EmptyArray => {
self.own_var(z);
}
AccessAtIndex { structure: x, .. } => {
// if the structure (record/tag/array) is owned, the extracted value is
if self.is_owned(*x) {
self.own_var(z);
}
// if the extracted value is owned, the structure must be too
if self.is_owned(z) {
self.own_var(*x);
}
}
FunctionCall {
call_type,
args,
arg_layouts,
..
} => {
// get the borrow signature of the applied function
let ps = match self.param_map.get_symbol(call_type.get_inner()) {
Some(slice) => slice,
None => Vec::from_iter_in(
arg_layouts.iter().cloned().map(|layout| Param {
symbol: Symbol::UNDERSCORE,
borrow: false,
layout,
}),
self.arena,
)
.into_bump_slice(),
};
// the return value will be owned
self.own_var(z);
// if the function exects an owned argument (ps), the argument must be owned (args)
self.own_args_using_params(args, ps);
}
RunLowLevel(op, args) => {
// very unsure what demand RunLowLevel should place upon its arguments
self.own_var(z);
let ps = lowlevel_borrow_signature(self.arena, *op);
self.own_args_using_bools(args, ps);
}
Literal(_) | FunctionPointer(_, _) | RuntimeErrorFunction(_) => {}
}
}
fn preserve_tail_call(&mut self, x: Symbol, v: &Expr<'a>, b: &Stmt<'a>) {
if let (
Expr::FunctionCall {
call_type,
args: ys,
..
},
Stmt::Ret(z),
) = (v, b)
{
let g = call_type.get_inner();
if self.current_proc == g && x == *z {
// anonymous functions (for which the ps may not be known)
// can never be tail-recursive, so this is fine
if let Some(ps) = self.param_map.get_symbol(g) {
self.own_params_using_args(ys, ps)
}
}
}
}
fn update_param_set(&mut self, ps: &[Param<'a>]) {
for p in ps.iter() {
self.param_set.insert(p.symbol);
}
}
fn update_param_set_symbols(&mut self, ps: &[Symbol]) {
for p in ps.iter() {
self.param_set.insert(*p);
}
}
fn collect_stmt(&mut self, stmt: &Stmt<'a>) {
use Stmt::*;
match stmt {
Join {
id: j,
parameters: ys,
remainder: v,
continuation: b,
} => {
let old = self.param_set.clone();
self.update_param_set(ys);
self.collect_stmt(v);
self.param_set = old;
self.update_param_map(Key::JoinPoint(*j));
self.collect_stmt(b);
}
Let(x, Expr::FunctionPointer(fsymbol, layout), _, b) => {
// ensure that the function pointed to is in the param map
if let Some(params) = self.param_map.get_symbol(*fsymbol) {
self.param_map.items.insert(Key::Declaration(*x), params);
}
self.collect_stmt(b);
self.preserve_tail_call(*x, &Expr::FunctionPointer(*fsymbol, layout.clone()), b);
}
Let(x, v, _, b) => {
self.collect_stmt(b);
self.collect_expr(*x, v);
self.preserve_tail_call(*x, v, b);
}
Jump(j, ys) => {
let ps = self.param_map.get_join_point(*j);
// for making sure the join point can reuse
self.own_args_using_params(ys, ps);
// for making sure the tail call is preserved
self.own_params_using_args(ys, ps);
}
Cond { pass, fail, .. } => {
self.collect_stmt(pass);
self.collect_stmt(fail);
}
Switch {
branches,
default_branch,
..
} => {
for (_, b) in branches.iter() {
self.collect_stmt(b);
}
self.collect_stmt(default_branch);
}
Inc(_, _) | Dec(_, _) => unreachable!("these have not been introduced yet"),
Ret(_) | RuntimeError(_) => {
// these are terminal, do nothing
}
}
}
fn collect_proc(&mut self, proc: &Proc<'a>) {
let old = self.param_set.clone();
let ys = Vec::from_iter_in(proc.args.iter().map(|t| t.1), self.arena).into_bump_slice();
self.update_param_set_symbols(ys);
self.current_proc = proc.name;
// ensure that current_proc is in the owned map
self.owned.entry(proc.name).or_default();
self.collect_stmt(&proc.body);
self.update_param_map(Key::Declaration(proc.name));
self.param_set = old;
}
}
pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] {
use LowLevel::*;
// TODO is true or false more efficient for non-refcounted layouts?
let irrelevant = false;
let owned = false;
let borrowed = true;
// Here we define the borrow signature of low-level operations
//
// - arguments with non-refcounted layouts (ints, floats) are `irrelevant`
// - arguments that we may want to update destructively must be Owned
// - other refcounted arguments are Borrowed
match op {
ListLen => arena.alloc_slice_copy(&[borrowed]),
ListSet => arena.alloc_slice_copy(&[owned, irrelevant, irrelevant]),
ListSetInPlace => arena.alloc_slice_copy(&[owned, irrelevant, irrelevant]),
ListGetUnsafe => arena.alloc_slice_copy(&[borrowed, irrelevant]),
ListSingle => arena.alloc_slice_copy(&[irrelevant]),
ListRepeat => arena.alloc_slice_copy(&[irrelevant, irrelevant]),
ListReverse => arena.alloc_slice_copy(&[owned]),
ListConcat => arena.alloc_slice_copy(&[irrelevant, irrelevant]),
ListAppend => arena.alloc_slice_copy(&[owned, owned]),
ListPrepend => arena.alloc_slice_copy(&[owned, owned]),
ListJoin => arena.alloc_slice_copy(&[irrelevant]),
Eq | NotEq | And | Or | NumAdd | NumSub | NumMul | NumGt | NumGte | NumLt | NumLte
| NumDivUnchecked | NumRemUnchecked => arena.alloc_slice_copy(&[irrelevant, irrelevant]),
NumAbs | NumNeg | NumSin | NumCos | NumSqrtUnchecked | NumRound | NumToFloat | Not => {
arena.alloc_slice_copy(&[irrelevant])
}
}
}

View file

@ -1,3 +1,4 @@
use crate::borrow::ParamMap;
use crate::ir::{Expr, JoinPointId, Param, Proc, Stmt}; use crate::ir::{Expr, JoinPointId, Param, Proc, Stmt};
use crate::layout::Layout; use crate::layout::Layout;
use bumpalo::collections::Vec; use bumpalo::collections::Vec;
@ -113,7 +114,11 @@ pub fn occuring_variables_expr(expr: &Expr<'_>, result: &mut MutSet<Symbol>) {
result.extend(arguments.iter().copied()); result.extend(arguments.iter().copied());
} }
RunLowLevel(_, _) | EmptyArray | RuntimeErrorFunction(_) | Literal(_) => {} RunLowLevel(_, args) => {
result.extend(args.iter());
}
EmptyArray | RuntimeErrorFunction(_) | Literal(_) => {}
} }
} }
@ -139,7 +144,7 @@ pub struct Context<'a> {
vars: VarMap, vars: VarMap,
jp_live_vars: JPLiveVarMap, // map: join point => live variables jp_live_vars: JPLiveVarMap, // map: join point => live variables
local_context: LocalContext<'a>, // we use it to store the join point declarations local_context: LocalContext<'a>, // we use it to store the join point declarations
function_params: MutMap<Symbol, &'a [Param<'a>]>, param_map: &'a ParamMap<'a>,
} }
fn update_live_vars<'a>(expr: &Expr<'a>, v: &LiveVarSet) -> LiveVarSet { fn update_live_vars<'a>(expr: &Expr<'a>, v: &LiveVarSet) -> LiveVarSet {
@ -150,6 +155,7 @@ fn update_live_vars<'a>(expr: &Expr<'a>, v: &LiveVarSet) -> LiveVarSet {
v v
} }
/// `isFirstOcc xs x i = true` if `xs[i]` is the first occurrence of `xs[i]` in `xs`
fn is_first_occurence(xs: &[Symbol], i: usize) -> bool { fn is_first_occurence(xs: &[Symbol], i: usize) -> bool {
match xs.get(i) { match xs.get(i) {
None => unreachable!(), None => unreachable!(),
@ -157,6 +163,9 @@ fn is_first_occurence(xs: &[Symbol], i: usize) -> bool {
} }
} }
/// Return `n`, the number of times `x` is consumed.
/// - `ys` is a sequence of instruction parameters where we search for `x`.
/// - `consumeParamPred i = true` if parameter `i` is consumed.
fn get_num_consumptions<F>(x: Symbol, ys: &[Symbol], consume_param_pred: F) -> usize fn get_num_consumptions<F>(x: Symbol, ys: &[Symbol], consume_param_pred: F) -> usize
where where
F: Fn(usize) -> bool, F: Fn(usize) -> bool,
@ -171,6 +180,8 @@ where
n n
} }
/// Return true if `x` also occurs in `ys` in a position that is not consumed.
/// That is, it is also passed as a borrow reference.
fn is_borrow_param_help<F>(x: Symbol, ys: &[Symbol], consume_param_pred: F) -> bool fn is_borrow_param_help<F>(x: Symbol, ys: &[Symbol], consume_param_pred: F) -> bool
where where
F: Fn(usize) -> bool, F: Fn(usize) -> bool,
@ -182,11 +193,11 @@ where
fn is_borrow_param(x: Symbol, ys: &[Symbol], ps: &[Param]) -> bool { fn is_borrow_param(x: Symbol, ys: &[Symbol], ps: &[Param]) -> bool {
// default to owned arguments // default to owned arguments
let pred = |i: usize| match ps.get(i) { let is_owned = |i: usize| match ps.get(i) {
Some(param) => !param.borrow, Some(param) => !param.borrow,
None => true, None => unreachable!("or?"),
}; };
is_borrow_param_help(x, ys, pred) is_borrow_param_help(x, ys, is_owned)
} }
// We do not need to consume the projection of a variable that is not consumed // We do not need to consume the projection of a variable that is not consumed
@ -201,13 +212,13 @@ fn consume_expr(m: &VarMap, e: &Expr<'_>) -> bool {
} }
impl<'a> Context<'a> { impl<'a> Context<'a> {
pub fn new(arena: &'a Bump) -> Self { pub fn new(arena: &'a Bump, param_map: &'a ParamMap<'a>) -> Self {
Self { Self {
arena, arena,
vars: MutMap::default(), vars: MutMap::default(),
jp_live_vars: MutMap::default(), jp_live_vars: MutMap::default(),
local_context: LocalContext::default(), local_context: LocalContext::default(),
function_params: MutMap::default(), param_map,
} }
} }
@ -253,56 +264,13 @@ impl<'a> Context<'a> {
self.arena.alloc(Stmt::Dec(symbol, stmt)) self.arena.alloc(Stmt::Dec(symbol, stmt))
} }
fn add_inc_before_consume_all_help<F>(
&self,
xs: &[Symbol],
consume_param_pred: F,
mut b: &'a Stmt<'a>,
live_vars_after: &LiveVarSet,
) -> &'a Stmt<'a>
where
F: Fn(usize) -> bool + Clone,
{
for (i, x) in xs.iter().enumerate() {
let info = self.get_var_info(*x);
if !info.reference || !is_first_occurence(xs, i) {
// do nothing
} else {
// number of times the argument is used (in the body?)
let num_consumptions = get_num_consumptions(*x, xs, consume_param_pred.clone());
// `x` is not a variable that must be consumed by the current procedure
// `x` is live after executing instruction
// `x` is used in a position that is passed as a borrow reference
let lives_on = !info.consume
|| live_vars_after.contains(x)
|| is_borrow_param_help(*x, xs, consume_param_pred.clone());
let num_incs = if lives_on {
num_consumptions
} else {
num_consumptions - 1
};
// Lean can increment by more than 1 at once. Is that needed?
debug_assert!(num_incs <= 1);
if num_incs == 1 {
b = self.add_inc(*x, b);
}
}
}
b
}
fn add_inc_before_consume_all( fn add_inc_before_consume_all(
&self, &self,
xs: &[Symbol], xs: &[Symbol],
b: &'a Stmt<'a>, b: &'a Stmt<'a>,
live_vars_after: &LiveVarSet, live_vars_after: &LiveVarSet,
) -> &'a Stmt<'a> { ) -> &'a Stmt<'a> {
self.add_inc_before_consume_all_help(xs, |_: usize| true, b, live_vars_after) self.add_inc_before_help(xs, |_: usize| true, b, live_vars_after)
} }
fn add_inc_before_help<F>( fn add_inc_before_help<F>(
@ -321,11 +289,17 @@ impl<'a> Context<'a> {
// do nothing // do nothing
} else { } else {
let num_consumptions = get_num_consumptions(*x, xs, consume_param_pred.clone()); // number of times the argument is used let num_consumptions = get_num_consumptions(*x, xs, consume_param_pred.clone()); // number of times the argument is used
let num_incs = if !info.consume || // `x` is not a variable that must be consumed by the current procedure
live_vars_after.contains(x) || // `x` is live after executing instruction // `x` is not a variable that must be consumed by the current procedure
is_borrow_param_help( *x ,xs, consume_param_pred.clone()) let need_not_consume = !info.consume;
// `x` is live after executing instruction
let is_live_after = live_vars_after.contains(x);
// `x` is used in a position that is passed as a borrow reference // `x` is used in a position that is passed as a borrow reference
{ let is_borrowed = is_borrow_param_help(*x, xs, consume_param_pred.clone());
let num_incs = if need_not_consume || is_live_after || is_borrowed {
num_consumptions num_consumptions
} else { } else {
num_consumptions - 1 num_consumptions - 1
@ -352,7 +326,7 @@ impl<'a> Context<'a> {
// default to owned arguments // default to owned arguments
let pred = |i: usize| match ps.get(i) { let pred = |i: usize| match ps.get(i) {
Some(param) => !param.borrow, Some(param) => !param.borrow,
None => true, None => unreachable!("or?"),
}; };
self.add_inc_before_help(xs, pred, b, live_vars_after) self.add_inc_before_help(xs, pred, b, live_vars_after)
} }
@ -383,10 +357,10 @@ impl<'a> Context<'a> {
b_live_vars: &LiveVarSet, b_live_vars: &LiveVarSet,
) -> &'a Stmt<'a> { ) -> &'a Stmt<'a> {
for (i, x) in xs.iter().enumerate() { for (i, x) in xs.iter().enumerate() {
/* We must add a `dec` if `x` must be consumed, it is alive after the application, // We must add a `dec` if `x` must be consumed, it is alive after the application,
and it has been borrowed by the application. // and it has been borrowed by the application.
Remark: `x` may occur multiple times in the application (e.g., `f x y x`). // Remark: `x` may occur multiple times in the application (e.g., `f x y x`).
This is why we check whether it is the first occurrence. */ // This is why we check whether it is the first occurrence.
if self.must_consume(*x) if self.must_consume(*x)
&& is_first_occurence(xs, i) && is_first_occurence(xs, i)
&& is_borrow_param(*x, xs, ps) && is_borrow_param(*x, xs, ps)
@ -399,6 +373,31 @@ impl<'a> Context<'a> {
b b
} }
fn add_dec_after_lowlevel(
&self,
xs: &[Symbol],
ps: &[bool],
mut b: &'a Stmt<'a>,
b_live_vars: &LiveVarSet,
) -> &'a Stmt<'a> {
for (i, (x, is_borrow)) in xs.iter().zip(ps.iter()).enumerate() {
/* We must add a `dec` if `x` must be consumed, it is alive after the application,
and it has been borrowed by the application.
Remark: `x` may occur multiple times in the application (e.g., `f x y x`).
This is why we check whether it is the first occurrence. */
if self.must_consume(*x)
&& is_first_occurence(xs, i)
&& *is_borrow
&& !b_live_vars.contains(x)
{
b = self.add_dec(*x, b);
}
}
b
}
#[allow(clippy::many_single_char_names)] #[allow(clippy::many_single_char_names)]
fn visit_variable_declaration( fn visit_variable_declaration(
&self, &self,
@ -432,54 +431,37 @@ impl<'a> Context<'a> {
self.arena.alloc(Stmt::Let(z, v, l, b)) self.arena.alloc(Stmt::Let(z, v, l, b))
} }
RunLowLevel(_, _) => { RunLowLevel(op, args) => {
// THEORY: runlowlevel only occurs let ps = crate::borrow::lowlevel_borrow_signature(self.arena, op);
// let b = self.add_dec_after_lowlevel(args, ps, b, b_live_vars);
// - in a custom hard-coded function
// - when we insert them as compiler authors
//
// if we're carefule to only use RunLowLevel for non-rc'd types
// (e.g. when building a cond/switch, we check equality on integers, and to boolean and)
// then RunLowLevel should not change in any way the refcounts.
// let b = self.add_dec_after_application(ys, ps, b, b_live_vars);
self.arena.alloc(Stmt::Let(z, v, l, b)) self.arena.alloc(Stmt::Let(z, v, l, b))
} }
FunctionCall { FunctionCall {
args: ys, args: ys,
call_type,
arg_layouts, arg_layouts,
call_type,
.. ..
} => { } => {
// this is where the borrow signature would come in // get the borrow signature
//let ps := (getDecl ctx f).params; let ps = match self.param_map.get_symbol(call_type.get_inner()) {
use crate::ir::CallType; Some(slice) => slice,
use crate::layout::Builtin; None => Vec::from_iter_in(
let symbol = match call_type { arg_layouts.iter().cloned().map(|layout| Param {
CallType::ByName(s) => s, symbol: Symbol::UNDERSCORE,
CallType::ByPointer(s) => s, borrow: false,
layout,
}),
self.arena,
)
.into_bump_slice(),
}; };
let ps = Vec::from_iter_in(
arg_layouts.iter().map(|layout| {
let borrow = match layout {
Layout::Builtin(Builtin::List(_, _)) => true,
_ => false,
};
Param {
symbol,
borrow,
layout: layout.clone(),
}
}),
self.arena,
)
.into_bump_slice();
let b = self.add_dec_after_application(ys, ps, b, b_live_vars); let b = self.add_dec_after_application(ys, ps, b, b_live_vars);
self.arena.alloc(Stmt::Let(z, v, l, b)) let b = self.arena.alloc(Stmt::Let(z, v, l, b));
self.add_inc_before(ys, ps, b, b_live_vars)
} }
EmptyArray | FunctionPointer(_, _) | Literal(_) | RuntimeErrorFunction(_) => { EmptyArray | FunctionPointer(_, _) | Literal(_) | RuntimeErrorFunction(_) => {
@ -495,13 +477,15 @@ impl<'a> Context<'a> {
fn update_var_info(&self, symbol: Symbol, layout: &Layout<'a>, expr: &Expr<'a>) -> Self { fn update_var_info(&self, symbol: Symbol, layout: &Layout<'a>, expr: &Expr<'a>) -> Self {
let mut ctx = self.clone(); let mut ctx = self.clone();
// TODO actually make these non-constant
// can this type be reference-counted at runtime? // can this type be reference-counted at runtime?
let reference = layout.contains_refcounted(); let reference = layout.contains_refcounted();
// is this value a constant? // is this value a constant?
let persistent = false; // TODO do function pointers also fall into this category?
let persistent = match expr {
Expr::FunctionCall { args, .. } => args.is_empty(),
_ => false,
};
// must this value be consumed? // must this value be consumed?
let consume = consume_expr(&ctx.vars, expr); let consume = consume_expr(&ctx.vars, expr);
@ -518,9 +502,6 @@ impl<'a> Context<'a> {
} }
fn update_var_info_with_params(&self, ps: &[Param]) -> Self { fn update_var_info_with_params(&self, ps: &[Param]) -> Self {
//def updateVarInfoWithParams (ctx : Context) (ps : Array Param) : Context :=
//let m := ps.foldl (fun (m : VarMap) p => m.insert p.x { ref := p.ty.isObj, consume := !p.borrow }) ctx.varMap;
//{ ctx with varMap := m }
let mut ctx = self.clone(); let mut ctx = self.clone();
for p in ps.iter() { for p in ps.iter() {
@ -535,8 +516,13 @@ impl<'a> Context<'a> {
ctx ctx
} }
/* Add `dec` instructions for parameters that are references, are not alive in `b`, and are not borrow. // Add `dec` instructions for parameters that are
That is, we must make sure these parameters are consumed. */ //
// - references
// - not alive in `b`
// - not borrow.
//
// That is, we must make sure these parameters are consumed.
fn add_dec_for_dead_params( fn add_dec_for_dead_params(
&self, &self,
ps: &[Param<'a>], ps: &[Param<'a>],
@ -619,25 +605,20 @@ impl<'a> Context<'a> {
Join { Join {
id: j, id: j,
parameters: xs, parameters: _,
remainder: b, remainder: b,
continuation: v, continuation: v,
} => { } => {
let xs = *xs; // get the parameters with borrow signature
let xs = self.param_map.get_join_point(*j);
let v_orig = v;
// NOTE deviation from lean, insert into local context
let mut ctx = self.clone();
ctx.local_context.join_points.insert(*j, (xs, v_orig));
let (v, v_live_vars) = { let (v, v_live_vars) = {
let ctx = ctx.update_var_info_with_params(xs); let ctx = self.update_var_info_with_params(xs);
ctx.visit_stmt(v) ctx.visit_stmt(v)
}; };
let mut ctx = self.clone();
let v = ctx.add_dec_for_dead_params(xs, v, &v_live_vars); let v = ctx.add_dec_for_dead_params(xs, v, &v_live_vars);
let mut ctx = ctx.clone();
update_jp_live_vars(*j, xs, v, &mut ctx.jp_live_vars); update_jp_live_vars(*j, xs, v, &mut ctx.jp_live_vars);
@ -673,7 +654,10 @@ impl<'a> Context<'a> {
Some(vars) => vars, Some(vars) => vars,
None => &empty, None => &empty,
}; };
let ps = self.local_context.join_points.get(j).unwrap().0; // TODO use borrow signature here?
let ps = self.param_map.get_join_point(*j);
// let ps = self.local_context.join_points.get(j).unwrap().0;
let b = self.add_inc_before(xs, ps, stmt, j_live_vars); let b = self.add_inc_before(xs, ps, stmt, j_live_vars);
let b_live_vars = collect_stmt(b, &self.jp_live_vars, MutSet::default()); let b_live_vars = collect_stmt(b, &self.jp_live_vars, MutSet::default());
@ -796,8 +780,15 @@ pub fn collect_stmt(
collect_stmt(cont, jp_live_vars, vars) collect_stmt(cont, jp_live_vars, vars)
} }
Jump(_, arguments) => { Jump(id, arguments) => {
vars.extend(arguments.iter().copied()); vars.extend(arguments.iter().copied());
// NOTE deviation from Lean
// we fall through when no join point is available
if let Some(jvars) = jp_live_vars.get(id) {
vars.extend(jvars);
}
vars vars
} }
@ -866,8 +857,13 @@ fn update_jp_live_vars(j: JoinPointId, ys: &[Param], v: &Stmt<'_>, m: &mut JPLiv
m.insert(j, j_live_vars); m.insert(j, j_live_vars);
} }
pub fn visit_declaration<'a>(arena: &'a Bump, stmt: &'a Stmt<'a>) -> &'a Stmt<'a> { /// used to process the main function in the repl
let ctx = Context::new(arena); pub fn visit_declaration<'a>(
arena: &'a Bump,
param_map: &'a ParamMap<'a>,
stmt: &'a Stmt<'a>,
) -> &'a Stmt<'a> {
let ctx = Context::new(arena, param_map);
let params = &[] as &[_]; let params = &[] as &[_];
let ctx = ctx.update_var_info_with_params(params); let ctx = ctx.update_var_info_with_params(params);
@ -875,23 +871,21 @@ pub fn visit_declaration<'a>(arena: &'a Bump, stmt: &'a Stmt<'a>) -> &'a Stmt<'a
ctx.add_dec_for_dead_params(params, b, &b_live_vars) ctx.add_dec_for_dead_params(params, b, &b_live_vars)
} }
pub fn visit_proc<'a>(arena: &'a Bump, proc: &mut Proc<'a>) { pub fn visit_proc<'a>(arena: &'a Bump, param_map: &'a ParamMap<'a>, proc: &mut Proc<'a>) {
let ctx = Context::new(arena); let ctx = Context::new(arena, param_map);
if proc.name.is_builtin() { let params = match param_map.get_symbol(proc.name) {
// we must take care of our own refcounting in builtins Some(slice) => slice,
return; None => Vec::from_iter_in(
} proc.args.iter().cloned().map(|(layout, symbol)| Param {
symbol,
let params = Vec::from_iter_in( borrow: false,
proc.args.iter().map(|(layout, symbol)| Param { layout,
symbol: *symbol, }),
layout: layout.clone(), arena,
borrow: layout.contains_refcounted(), )
}), .into_bump_slice(),
arena, };
)
.into_bump_slice();
let stmt = arena.alloc(proc.body.clone()); let stmt = arena.alloc(proc.body.clone());
let ctx = ctx.update_var_info_with_params(params); let ctx = ctx.update_var_info_with_params(params);

View file

@ -23,7 +23,7 @@ pub struct PartialProc<'a> {
pub annotation: Variable, pub annotation: Variable,
pub pattern_symbols: Vec<'a, Symbol>, pub pattern_symbols: Vec<'a, Symbol>,
pub body: roc_can::expr::Expr, pub body: roc_can::expr::Expr,
pub is_tail_recursive: bool, pub is_self_recursive: bool,
} }
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
@ -40,7 +40,13 @@ pub struct Proc<'a> {
pub body: Stmt<'a>, pub body: Stmt<'a>,
pub closes_over: Layout<'a>, pub closes_over: Layout<'a>,
pub ret_layout: Layout<'a>, pub ret_layout: Layout<'a>,
pub is_tail_recursive: bool, pub is_self_recursive: SelfRecursive,
}
#[derive(Clone, Debug, PartialEq)]
pub enum SelfRecursive {
NotSelfRecursive,
SelfRecursive(JoinPointId),
} }
impl<'a> Proc<'a> { impl<'a> Proc<'a> {
@ -111,15 +117,74 @@ impl<'a> Procs<'a> {
for (key, in_prog_proc) in self.specialized.into_iter() { for (key, in_prog_proc) in self.specialized.into_iter() {
match in_prog_proc { match in_prog_proc {
InProgress => unreachable!("The procedure {:?} should have be done by now", key), InProgress => unreachable!("The procedure {:?} should have be done by now", key),
Done(mut proc) => { Done(proc) => {
crate::inc_dec::visit_proc(arena, &mut proc);
result.insert(key, proc); result.insert(key, proc);
} }
} }
} }
for (_, proc) in result.iter_mut() {
use self::SelfRecursive::*;
if let SelfRecursive(id) = proc.is_self_recursive {
proc.body = crate::tail_recursion::make_tail_recursive(
arena,
id,
proc.name,
proc.body.clone(),
proc.args,
);
}
}
let borrow_params = arena.alloc(crate::borrow::infer_borrow(arena, &result));
for (_, proc) in result.iter_mut() {
crate::inc_dec::visit_proc(arena, borrow_params, proc);
}
result result
} }
pub fn get_specialized_procs_help(
self,
arena: &'a Bump,
) -> (
MutMap<(Symbol, Layout<'a>), Proc<'a>>,
&'a crate::borrow::ParamMap<'a>,
) {
let mut result = MutMap::with_capacity_and_hasher(self.specialized.len(), default_hasher());
for (key, in_prog_proc) in self.specialized.into_iter() {
match in_prog_proc {
InProgress => unreachable!("The procedure {:?} should have be done by now", key),
Done(proc) => {
result.insert(key, proc);
}
}
}
for (_, proc) in result.iter_mut() {
use self::SelfRecursive::*;
if let SelfRecursive(id) = proc.is_self_recursive {
proc.body = crate::tail_recursion::make_tail_recursive(
arena,
id,
proc.name,
proc.body.clone(),
proc.args,
);
}
}
let borrow_params = arena.alloc(crate::borrow::infer_borrow(arena, &result));
for (_, proc) in result.iter_mut() {
crate::inc_dec::visit_proc(arena, borrow_params, proc);
}
(result, borrow_params)
}
// TODO trim down these arguments! // TODO trim down these arguments!
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn insert_named( pub fn insert_named(
@ -130,7 +195,7 @@ impl<'a> Procs<'a> {
annotation: Variable, annotation: Variable,
loc_args: std::vec::Vec<(Variable, Located<roc_can::pattern::Pattern>)>, loc_args: std::vec::Vec<(Variable, Located<roc_can::pattern::Pattern>)>,
loc_body: Located<roc_can::expr::Expr>, loc_body: Located<roc_can::expr::Expr>,
is_tail_recursive: bool, is_self_recursive: bool,
ret_var: Variable, ret_var: Variable,
) { ) {
match patterns_to_when(env, layout_cache, loc_args, ret_var, loc_body) { match patterns_to_when(env, layout_cache, loc_args, ret_var, loc_body) {
@ -145,7 +210,7 @@ impl<'a> Procs<'a> {
annotation, annotation,
pattern_symbols, pattern_symbols,
body: body.value, body: body.value,
is_tail_recursive, is_self_recursive,
}, },
); );
} }
@ -179,7 +244,7 @@ impl<'a> Procs<'a> {
layout_cache: &mut LayoutCache<'a>, layout_cache: &mut LayoutCache<'a>,
) -> Result<Layout<'a>, RuntimeError> { ) -> Result<Layout<'a>, RuntimeError> {
// anonymous functions cannot reference themselves, therefore cannot be tail-recursive // anonymous functions cannot reference themselves, therefore cannot be tail-recursive
let is_tail_recursive = false; let is_self_recursive = false;
match patterns_to_when(env, layout_cache, loc_args, ret_var, loc_body) { match patterns_to_when(env, layout_cache, loc_args, ret_var, loc_body) {
Ok((pattern_vars, pattern_symbols, body)) => { Ok((pattern_vars, pattern_symbols, body)) => {
@ -219,7 +284,7 @@ impl<'a> Procs<'a> {
annotation, annotation,
pattern_symbols, pattern_symbols,
body: body.value, body: body.value,
is_tail_recursive, is_self_recursive,
}, },
); );
} }
@ -229,7 +294,7 @@ impl<'a> Procs<'a> {
annotation, annotation,
pattern_symbols, pattern_symbols,
body: body.value, body: body.value,
is_tail_recursive, is_self_recursive,
}; };
// Mark this proc as in-progress, so if we're dealing with // Mark this proc as in-progress, so if we're dealing with
@ -459,6 +524,15 @@ pub enum CallType {
ByPointer(Symbol), ByPointer(Symbol),
} }
impl CallType {
pub fn get_inner(&self) -> Symbol {
match self {
CallType::ByName(s) => *s,
CallType::ByPointer(s) => *s,
}
}
}
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub enum Expr<'a> { pub enum Expr<'a> {
Literal(Literal<'a>), Literal(Literal<'a>),
@ -1001,7 +1075,7 @@ fn specialize<'a>(
annotation, annotation,
pattern_symbols, pattern_symbols,
body, body,
is_tail_recursive, is_self_recursive,
} = partial_proc; } = partial_proc;
// unify the called function with the specialized signature, then specialize the function body // unify the called function with the specialized signature, then specialize the function body
@ -1031,9 +1105,6 @@ fn specialize<'a>(
let proc_args = proc_args.into_bump_slice(); let proc_args = proc_args.into_bump_slice();
let specialized_body =
crate::tail_recursion::make_tail_recursive(env, proc_name, specialized_body, proc_args);
let ret_layout = layout_cache let ret_layout = layout_cache
.from_var(&env.arena, ret_var, env.subs) .from_var(&env.arena, ret_var, env.subs)
.unwrap_or_else(|err| panic!("TODO handle invalid function {:?}", err)); .unwrap_or_else(|err| panic!("TODO handle invalid function {:?}", err));
@ -1041,13 +1112,19 @@ fn specialize<'a>(
// TODO WRONG // TODO WRONG
let closes_over_layout = Layout::Struct(&[]); let closes_over_layout = Layout::Struct(&[]);
let recursivity = if is_self_recursive {
SelfRecursive::SelfRecursive(JoinPointId(env.unique_symbol()))
} else {
SelfRecursive::NotSelfRecursive
};
let proc = Proc { let proc = Proc {
name: proc_name, name: proc_name,
args: proc_args, args: proc_args,
body: specialized_body, body: specialized_body,
closes_over: closes_over_layout, closes_over: closes_over_layout,
ret_layout, ret_layout,
is_tail_recursive, is_self_recursive: recursivity,
}; };
Ok(proc) Ok(proc)
@ -1109,8 +1186,8 @@ pub fn with_hole<'a>(
let (loc_body, ret_var) = *boxed_body; let (loc_body, ret_var) = *boxed_body;
let is_tail_recursive = let is_self_recursive =
matches!(recursivity, roc_can::expr::Recursive::TailRecursive); !matches!(recursivity, roc_can::expr::Recursive::NotRecursive);
procs.insert_named( procs.insert_named(
env, env,
@ -1119,7 +1196,7 @@ pub fn with_hole<'a>(
ann, ann,
loc_args, loc_args,
loc_body, loc_body,
is_tail_recursive, is_self_recursive,
ret_var, ret_var,
); );
@ -1193,8 +1270,8 @@ pub fn with_hole<'a>(
let (loc_body, ret_var) = *boxed_body; let (loc_body, ret_var) = *boxed_body;
let is_tail_recursive = let is_self_recursive =
matches!(recursivity, roc_can::expr::Recursive::TailRecursive); !matches!(recursivity, roc_can::expr::Recursive::NotRecursive);
procs.insert_named( procs.insert_named(
env, env,
@ -1203,7 +1280,7 @@ pub fn with_hole<'a>(
ann, ann,
loc_args, loc_args,
loc_body, loc_body,
is_tail_recursive, is_self_recursive,
ret_var, ret_var,
); );
@ -2060,8 +2137,8 @@ pub fn from_can<'a>(
let (loc_body, ret_var) = *boxed_body; let (loc_body, ret_var) = *boxed_body;
let is_tail_recursive = let is_self_recursive =
matches!(recursivity, roc_can::expr::Recursive::TailRecursive); !matches!(recursivity, roc_can::expr::Recursive::NotRecursive);
procs.insert_named( procs.insert_named(
env, env,
@ -2070,7 +2147,7 @@ pub fn from_can<'a>(
ann, ann,
loc_args, loc_args,
loc_body, loc_body,
is_tail_recursive, is_self_recursive,
ret_var, ret_var,
); );
@ -2096,8 +2173,8 @@ pub fn from_can<'a>(
let (loc_body, ret_var) = *boxed_body; let (loc_body, ret_var) = *boxed_body;
let is_tail_recursive = let is_self_recursive =
matches!(recursivity, roc_can::expr::Recursive::TailRecursive); !matches!(recursivity, roc_can::expr::Recursive::NotRecursive);
procs.insert_named( procs.insert_named(
env, env,
@ -2106,7 +2183,7 @@ pub fn from_can<'a>(
ann, ann,
loc_args, loc_args,
loc_body, loc_body,
is_tail_recursive, is_self_recursive,
ret_var, ret_var,
); );

View file

@ -22,6 +22,7 @@ pub enum Layout<'a> {
Builtin(Builtin<'a>), Builtin(Builtin<'a>),
Struct(&'a [Layout<'a>]), Struct(&'a [Layout<'a>]),
Union(&'a [&'a [Layout<'a>]]), Union(&'a [&'a [Layout<'a>]]),
RecursiveUnion(&'a [&'a [Layout<'a>]]),
/// A function. The types of its arguments, then the type of its return value. /// A function. The types of its arguments, then the type of its return value.
FunctionPointer(&'a [Layout<'a>], &'a Layout<'a>), FunctionPointer(&'a [Layout<'a>], &'a Layout<'a>),
Pointer(&'a Layout<'a>), Pointer(&'a Layout<'a>),
@ -96,6 +97,10 @@ impl<'a> Layout<'a> {
Union(tags) => tags Union(tags) => tags
.iter() .iter()
.all(|tag_layout| tag_layout.iter().all(|field| field.safe_to_memcpy())), .all(|tag_layout| tag_layout.iter().all(|field| field.safe_to_memcpy())),
RecursiveUnion(_) => {
// a recursive union will always contain a pointer, and are thus not safe to memcpy
false
}
FunctionPointer(_, _) => { FunctionPointer(_, _) => {
// Function pointers are immutable and can always be safely copied // Function pointers are immutable and can always be safely copied
true true
@ -138,6 +143,16 @@ impl<'a> Layout<'a> {
}) })
.max() .max()
.unwrap_or_default(), .unwrap_or_default(),
RecursiveUnion(fields) => fields
.iter()
.map(|tag_layout| {
tag_layout
.iter()
.map(|field| field.stack_size(pointer_size))
.sum()
})
.max()
.unwrap_or_default(),
FunctionPointer(_, _) => pointer_size, FunctionPointer(_, _) => pointer_size,
Pointer(_) => pointer_size, Pointer(_) => pointer_size,
} }
@ -146,6 +161,7 @@ impl<'a> Layout<'a> {
pub fn is_refcounted(&self) -> bool { pub fn is_refcounted(&self) -> bool {
match self { match self {
Layout::Builtin(Builtin::List(_, _)) => true, Layout::Builtin(Builtin::List(_, _)) => true,
Layout::RecursiveUnion(_) => true,
_ => false, _ => false,
} }
} }
@ -164,6 +180,7 @@ impl<'a> Layout<'a> {
.map(|ls| ls.iter()) .map(|ls| ls.iter())
.flatten() .flatten()
.any(|f| f.is_refcounted()), .any(|f| f.is_refcounted()),
RecursiveUnion(_) => true,
FunctionPointer(_, _) | Pointer(_) => false, FunctionPointer(_, _) | Pointer(_) => false,
} }
} }
@ -406,8 +423,41 @@ fn layout_from_flat_type<'a>(
Ok(layout_from_tag_union(arena, tags, subs)) Ok(layout_from_tag_union(arena, tags, subs))
} }
RecursiveTagUnion(_rec_var, _tags, _ext_var) => { RecursiveTagUnion(_rec_var, _tags, ext_var) => {
panic!("TODO make Layout for empty RecursiveTagUnion"); debug_assert!(ext_var_is_empty_tag_union(subs, ext_var));
// some observations
//
// * recursive tag unions are always recursive
// * therefore at least one tag has a pointer (non-zero sized) field
// * they must (to be instantiated) have 2 or more tags
//
// That means none of the optimizations for enums or single tag tag unions apply
// let rec_var = subs.get_root_key_without_compacting(rec_var);
// let mut tag_layouts = Vec::with_capacity_in(tags.len(), arena);
//
// // tags: MutMap<TagName, std::vec::Vec<Variable>>,
// for (_name, variables) in tags {
// let mut tag_layout = Vec::with_capacity_in(variables.len(), arena);
//
// for var in variables {
// // TODO does this still cause problems with mutually recursive unions?
// if rec_var == subs.get_root_key_without_compacting(var) {
// // TODO make this a pointer?
// continue;
// }
//
// let var_content = subs.get_without_compacting(var).content;
//
// tag_layout.push(Layout::new(arena, var_content, subs)?);
// }
//
// tag_layouts.push(tag_layout.into_bump_slice());
// }
//
// Ok(Layout::RecursiveUnion(tag_layouts.into_bump_slice()))
Ok(Layout::RecursiveUnion(&[]))
} }
EmptyTagUnion => { EmptyTagUnion => {
panic!("TODO make Layout for empty Tag Union"); panic!("TODO make Layout for empty Tag Union");

View file

@ -11,6 +11,7 @@
// re-enable this when working on performance optimizations than have it block PRs. // re-enable this when working on performance optimizations than have it block PRs.
#![allow(clippy::large_enum_variant)] #![allow(clippy::large_enum_variant)]
pub mod borrow;
pub mod inc_dec; pub mod inc_dec;
pub mod ir; pub mod ir;
pub mod layout; pub mod layout;

View file

@ -1,19 +1,40 @@
use crate::ir::{CallType, Env, Expr, JoinPointId, Param, Stmt}; use crate::ir::{CallType, Expr, JoinPointId, Param, Stmt};
use crate::layout::Layout; use crate::layout::Layout;
use bumpalo::collections::Vec; use bumpalo::collections::Vec;
use bumpalo::Bump; use bumpalo::Bump;
use roc_module::symbol::Symbol; use roc_module::symbol::Symbol;
/// Make tail calls into loops (using join points)
///
/// e.g.
///
/// > factorial n accum = if n == 1 then accum else factorial (n - 1) (n * accum)
///
/// becomes
///
/// ```elm
/// factorial n1 accum1 =
/// let joinpoint j n accum =
/// if n == 1 then
/// accum
/// else
/// jump j (n - 1) (n * accum)
///
/// in
/// jump j n1 accum1
/// ```
///
/// This will effectively compile into a loop in llvm, and
/// won't grow the call stack for each iteration
pub fn make_tail_recursive<'a>( pub fn make_tail_recursive<'a>(
env: &mut Env<'a, '_>, arena: &'a Bump,
id: JoinPointId,
needle: Symbol, needle: Symbol,
stmt: Stmt<'a>, stmt: Stmt<'a>,
args: &'a [(Layout<'a>, Symbol)], args: &'a [(Layout<'a>, Symbol)],
) -> Stmt<'a> { ) -> Stmt<'a> {
let id = JoinPointId(env.unique_symbol()); let alloced = arena.alloc(stmt);
match insert_jumps(arena, alloced, id, needle) {
let alloced = env.arena.alloc(stmt);
match insert_jumps(env.arena, alloced, id, needle) {
None => alloced.clone(), None => alloced.clone(),
Some(new) => { Some(new) => {
// jumps were inserted, we must now add a join point // jumps were inserted, we must now add a join point
@ -24,13 +45,14 @@ pub fn make_tail_recursive<'a>(
layout: layout.clone(), layout: layout.clone(),
borrow: true, borrow: true,
}), }),
env.arena, arena,
) )
.into_bump_slice(); .into_bump_slice();
let args = Vec::from_iter_in(args.iter().map(|t| t.1), env.arena).into_bump_slice(); // TODO could this be &[]?
let args = Vec::from_iter_in(args.iter().map(|t| t.1), arena).into_bump_slice();
let jump = env.arena.alloc(Stmt::Jump(id, args)); let jump = arena.alloc(Stmt::Jump(id, args));
Stmt::Join { Stmt::Join {
id, id,
@ -185,7 +207,6 @@ fn insert_jumps<'a>(
None None
} }
} }
Ret(_) => None,
Inc(symbol, cont) => match insert_jumps(arena, cont, goal_id, needle) { Inc(symbol, cont) => match insert_jumps(arena, cont, goal_id, needle) {
Some(cont) => Some(arena.alloc(Inc(*symbol, cont))), Some(cont) => Some(arena.alloc(Inc(*symbol, cont))),
None => None, None => None,
@ -195,6 +216,7 @@ fn insert_jumps<'a>(
None => None, None => None,
}, },
Ret(_) => None,
Jump(_, _) => None, Jump(_, _) => None,
RuntimeError(_) => None, RuntimeError(_) => None,
} }

View file

@ -66,17 +66,18 @@ mod test_mono {
// let mono_expr = Expr::new(&mut mono_env, loc_expr.value, &mut procs); // let mono_expr = Expr::new(&mut mono_env, loc_expr.value, &mut procs);
let procs = roc_mono::ir::specialize_all(&mut mono_env, procs, &mut LayoutCache::default()); let procs = roc_mono::ir::specialize_all(&mut mono_env, procs, &mut LayoutCache::default());
// apply inc/dec
let stmt = mono_env.arena.alloc(ir_expr);
let ir_expr = roc_mono::inc_dec::visit_declaration(mono_env.arena, stmt);
assert_eq!( assert_eq!(
procs.runtime_errors, procs.runtime_errors,
roc_collections::all::MutMap::default() roc_collections::all::MutMap::default()
); );
let (procs, param_map) = procs.get_specialized_procs_help(mono_env.arena);
// apply inc/dec
let stmt = mono_env.arena.alloc(ir_expr);
let ir_expr = roc_mono::inc_dec::visit_declaration(mono_env.arena, param_map, stmt);
let mut procs_string = procs let mut procs_string = procs
.get_specialized_procs(mono_env.arena)
.values() .values()
.map(|proc| proc.to_pretty(200)) .map(|proc| proc.to_pretty(200))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@ -94,6 +95,7 @@ mod test_mono {
let result_lines = result.split("\n").collect::<Vec<&str>>(); let result_lines = result.split("\n").collect::<Vec<&str>>();
assert_eq!(expected_lines, result_lines); assert_eq!(expected_lines, result_lines);
//assert_eq!(0, 1);
} }
} }
@ -380,27 +382,35 @@ mod test_mono {
fn guard_pattern_true() { fn guard_pattern_true() {
compiles_to_ir( compiles_to_ir(
r#" r#"
when 2 is main = \{} ->
2 if False -> 42 when 2 is
_ -> 0 2 if False -> 42
_ -> 0
main {}
"#, "#,
indoc!( indoc!(
r#" r#"
let Test.0 = 2i64; procedure Test.0 (Test.2):
let Test.6 = true; let Test.5 = 2i64;
let Test.7 = 2i64; let Test.11 = true;
let Test.10 = lowlevel Eq Test.7 Test.0; let Test.12 = 2i64;
let Test.8 = lowlevel And Test.10 Test.6; let Test.15 = lowlevel Eq Test.12 Test.5;
let Test.3 = false; let Test.13 = lowlevel And Test.15 Test.11;
jump Test.2 Test.3; let Test.8 = false;
joinpoint Test.2 Test.9: jump Test.7 Test.8;
let Test.5 = lowlevel And Test.9 Test.8; joinpoint Test.7 Test.14:
if Test.5 then let Test.10 = lowlevel And Test.14 Test.13;
let Test.1 = 42i64; if Test.10 then
ret Test.1; let Test.6 = 42i64;
else ret Test.6;
let Test.4 = 0i64; else
ret Test.4; let Test.9 = 0i64;
ret Test.9;
let Test.4 = Struct {};
let Test.3 = CallByName Test.0 Test.4;
ret Test.3;
"# "#
), ),
) )
@ -539,7 +549,6 @@ mod test_mono {
let Test.6 = 2i64; let Test.6 = 2i64;
let Test.4 = Array [Test.5, Test.6]; let Test.4 = Array [Test.5, Test.6];
let Test.3 = CallByName Test.0 Test.4; let Test.3 = CallByName Test.0 Test.4;
dec Test.4;
ret Test.3; ret Test.3;
"# "#
), ),
@ -548,6 +557,8 @@ mod test_mono {
#[test] #[test]
fn list_append() { fn list_append() {
// TODO this leaks at the moment
// ListAppend needs to decrement its arguments
compiles_to_ir( compiles_to_ir(
r#" r#"
List.append [1] 2 List.append [1] 2
@ -562,7 +573,6 @@ mod test_mono {
let Test.1 = Array [Test.3]; let Test.1 = Array [Test.3];
let Test.2 = 2i64; let Test.2 = 2i64;
let Test.0 = CallByName List.5 Test.1 Test.2; let Test.0 = CallByName List.5 Test.1 Test.2;
dec Test.1;
ret Test.0; ret Test.0;
"# "#
), ),
@ -581,16 +591,16 @@ mod test_mono {
indoc!( indoc!(
r#" r#"
procedure Num.14 (#Attr.2, #Attr.3): procedure Num.14 (#Attr.2, #Attr.3):
let Test.13 = lowlevel NumAdd #Attr.2 #Attr.3; let Test.11 = lowlevel NumAdd #Attr.2 #Attr.3;
ret Test.13; ret Test.11;
procedure List.7 (#Attr.2): procedure List.7 (#Attr.2):
let Test.9 = lowlevel ListLen #Attr.2; let Test.9 = lowlevel ListLen #Attr.2;
ret Test.9; ret Test.9;
procedure List.7 (#Attr.2): procedure List.7 (#Attr.2):
let Test.11 = lowlevel ListLen #Attr.2; let Test.10 = lowlevel ListLen #Attr.2;
ret Test.11; ret Test.10;
let Test.8 = 1f64; let Test.8 = 1f64;
let Test.1 = Array [Test.8]; let Test.1 = Array [Test.8];
@ -613,35 +623,43 @@ mod test_mono {
fn when_joinpoint() { fn when_joinpoint() {
compiles_to_ir( compiles_to_ir(
r#" r#"
x : [ Red, White, Blue ] main = \{} ->
x = Blue x : [ Red, White, Blue ]
x = Blue
y = y =
when x is when x is
Red -> 1 Red -> 1
White -> 2 White -> 2
Blue -> 3 Blue -> 3
y y
main {}
"#, "#,
indoc!( indoc!(
r#" r#"
let Test.0 = 0u8; procedure Test.0 (Test.4):
switch Test.0: let Test.2 = 0u8;
case 1: switch Test.2:
let Test.4 = 1i64; case 1:
jump Test.3 Test.4; let Test.9 = 1i64;
jump Test.8 Test.9;
case 2:
let Test.10 = 2i64;
jump Test.8 Test.10;
default:
let Test.11 = 3i64;
jump Test.8 Test.11;
joinpoint Test.8 Test.3:
ret Test.3;
case 2: let Test.6 = Struct {};
let Test.5 = 2i64; let Test.5 = CallByName Test.0 Test.6;
jump Test.3 Test.5; ret Test.5;
default:
let Test.6 = 3i64;
jump Test.3 Test.6;
joinpoint Test.3 Test.1:
ret Test.1;
"# "#
), ),
) )
@ -704,43 +722,51 @@ mod test_mono {
fn when_on_result() { fn when_on_result() {
compiles_to_ir( compiles_to_ir(
r#" r#"
x : Result Int Int main = \{} ->
x = Ok 2 x : Result Int Int
x = Ok 2
y = y =
when x is when x is
Ok 3 -> 1 Ok 3 -> 1
Ok _ -> 2 Ok _ -> 2
Err _ -> 3 Err _ -> 3
y y
main {}
"#, "#,
indoc!( indoc!(
r#" r#"
let Test.17 = 1i64; procedure Test.0 (Test.4):
let Test.18 = 2i64; let Test.22 = 1i64;
let Test.0 = Ok Test.17 Test.18; let Test.23 = 2i64;
let Test.13 = true; let Test.2 = Ok Test.22 Test.23;
let Test.15 = Index 0 Test.0; let Test.18 = true;
let Test.14 = 1i64; let Test.20 = Index 0 Test.2;
let Test.16 = lowlevel Eq Test.14 Test.15; let Test.19 = 1i64;
let Test.12 = lowlevel And Test.16 Test.13; let Test.21 = lowlevel Eq Test.19 Test.20;
if Test.12 then let Test.17 = lowlevel And Test.21 Test.18;
let Test.8 = true; if Test.17 then
let Test.9 = 3i64; let Test.13 = true;
let Test.10 = Index 0 Test.0; let Test.14 = 3i64;
let Test.11 = lowlevel Eq Test.9 Test.10; let Test.15 = Index 0 Test.2;
let Test.7 = lowlevel And Test.11 Test.8; let Test.16 = lowlevel Eq Test.14 Test.15;
if Test.7 then let Test.12 = lowlevel And Test.16 Test.13;
let Test.4 = 1i64; if Test.12 then
jump Test.3 Test.4; let Test.9 = 1i64;
jump Test.8 Test.9;
else
let Test.10 = 2i64;
jump Test.8 Test.10;
else else
let Test.5 = 2i64; let Test.11 = 3i64;
jump Test.3 Test.5; jump Test.8 Test.11;
else joinpoint Test.8 Test.3:
let Test.6 = 3i64; ret Test.3;
jump Test.3 Test.6;
joinpoint Test.3 Test.1: let Test.6 = Struct {};
ret Test.1; let Test.5 = CallByName Test.0 Test.6;
ret Test.5;
"# "#
), ),
) )
@ -796,30 +822,38 @@ mod test_mono {
compiles_to_ir( compiles_to_ir(
indoc!( indoc!(
r#" r#"
when 10 is main = \{} ->
x if x == 5 -> 0 when 10 is
_ -> 42 x if x == 5 -> 0
_ -> 42
main {}
"# "#
), ),
indoc!( indoc!(
r#" r#"
procedure Bool.5 (#Attr.2, #Attr.3): procedure Test.0 (Test.3):
let Test.10 = lowlevel Eq #Attr.2 #Attr.3; let Test.6 = 10i64;
ret Test.10; let Test.14 = true;
let Test.10 = 5i64;
let Test.9 = CallByName Bool.5 Test.6 Test.10;
jump Test.8 Test.9;
joinpoint Test.8 Test.15:
let Test.13 = lowlevel And Test.15 Test.14;
if Test.13 then
let Test.7 = 0i64;
ret Test.7;
else
let Test.12 = 42i64;
ret Test.12;
let Test.1 = 10i64; procedure Bool.5 (#Attr.2, #Attr.3):
let Test.8 = true; let Test.11 = lowlevel Eq #Attr.2 #Attr.3;
let Test.5 = 5i64; ret Test.11;
let Test.4 = CallByName Bool.5 Test.1 Test.5;
jump Test.3 Test.4; let Test.5 = Struct {};
joinpoint Test.3 Test.9: let Test.4 = CallByName Test.0 Test.5;
let Test.7 = lowlevel And Test.9 Test.8; ret Test.4;
if Test.7 then
let Test.2 = 0i64;
ret Test.2;
else
let Test.6 = 42i64;
ret Test.6;
"# "#
), ),
) )
@ -905,12 +939,6 @@ mod test_mono {
), ),
indoc!( indoc!(
r#" r#"
procedure Test.1 (Test.3):
let Test.9 = 0i64;
let Test.10 = 0i64;
let Test.8 = CallByName List.4 Test.3 Test.9 Test.10;
ret Test.8;
procedure List.4 (#Attr.2, #Attr.3, #Attr.4): procedure List.4 (#Attr.2, #Attr.3, #Attr.4):
let Test.14 = lowlevel ListLen #Attr.2; let Test.14 = lowlevel ListLen #Attr.2;
let Test.12 = lowlevel NumLt #Attr.3 Test.14; let Test.12 = lowlevel NumLt #Attr.3 Test.14;
@ -920,12 +948,17 @@ mod test_mono {
else else
ret #Attr.2; ret #Attr.2;
procedure Test.1 (Test.3):
let Test.9 = 0i64;
let Test.10 = 0i64;
let Test.8 = CallByName List.4 Test.3 Test.9 Test.10;
ret Test.8;
let Test.5 = 1i64; let Test.5 = 1i64;
let Test.6 = 2i64; let Test.6 = 2i64;
let Test.7 = 3i64; let Test.7 = 3i64;
let Test.0 = Array [Test.5, Test.6, Test.7]; let Test.0 = Array [Test.5, Test.6, Test.7];
let Test.4 = CallByName Test.1 Test.0; let Test.4 = CallByName Test.1 Test.0;
dec Test.0;
ret Test.4; ret Test.4;
"# "#
), ),
@ -1066,7 +1099,8 @@ mod test_mono {
) )
} }
#[allow(dead_code)] #[ignore]
#[test]
fn quicksort_help() { fn quicksort_help() {
crate::helpers::with_larger_debug_stack(|| { crate::helpers::with_larger_debug_stack(|| {
compiles_to_ir( compiles_to_ir(
@ -1094,7 +1128,8 @@ mod test_mono {
}) })
} }
#[allow(dead_code)] #[ignore]
#[test]
fn quicksort_partition_help() { fn quicksort_partition_help() {
crate::helpers::with_larger_debug_stack(|| { crate::helpers::with_larger_debug_stack(|| {
compiles_to_ir( compiles_to_ir(
@ -1128,7 +1163,8 @@ mod test_mono {
}) })
} }
#[allow(dead_code)] #[ignore]
#[test]
fn quicksort_full() { fn quicksort_full() {
crate::helpers::with_larger_debug_stack(|| { crate::helpers::with_larger_debug_stack(|| {
compiles_to_ir( compiles_to_ir(
@ -1217,29 +1253,29 @@ mod test_mono {
"#, "#,
indoc!( indoc!(
r#" r#"
procedure Num.15 (#Attr.2, #Attr.3):
let Test.13 = lowlevel NumSub #Attr.2 #Attr.3;
ret Test.13;
procedure Test.0 (Test.2, Test.3): procedure Test.0 (Test.2, Test.3):
jump Test.20 Test.2 Test.3; jump Test.18 Test.2 Test.3;
joinpoint Test.20 Test.2 Test.3: joinpoint Test.18 Test.2 Test.3:
let Test.17 = true; let Test.15 = true;
let Test.18 = 0i64; let Test.16 = 0i64;
let Test.19 = lowlevel Eq Test.18 Test.2; let Test.17 = lowlevel Eq Test.16 Test.2;
let Test.16 = lowlevel And Test.19 Test.17; let Test.14 = lowlevel And Test.17 Test.15;
if Test.16 then if Test.14 then
ret Test.3; ret Test.3;
else else
let Test.13 = 1i64; let Test.12 = 1i64;
let Test.9 = CallByName Num.15 Test.2 Test.13; let Test.9 = CallByName Num.15 Test.2 Test.12;
let Test.10 = CallByName Num.16 Test.2 Test.3; let Test.10 = CallByName Num.16 Test.2 Test.3;
jump Test.20 Test.9 Test.10; jump Test.18 Test.9 Test.10;
procedure Num.16 (#Attr.2, #Attr.3): procedure Num.16 (#Attr.2, #Attr.3):
let Test.11 = lowlevel NumMul #Attr.2 #Attr.3; let Test.11 = lowlevel NumMul #Attr.2 #Attr.3;
ret Test.11; ret Test.11;
procedure Num.15 (#Attr.2, #Attr.3):
let Test.14 = lowlevel NumSub #Attr.2 #Attr.3;
ret Test.14;
let Test.5 = 10i64; let Test.5 = 10i64;
let Test.6 = 1i64; let Test.6 = 1i64;
let Test.4 = CallByName Test.0 Test.5 Test.6; let Test.4 = CallByName Test.0 Test.5 Test.6;
@ -1248,4 +1284,248 @@ mod test_mono {
), ),
) )
} }
#[test]
#[ignore]
fn is_nil() {
compiles_to_ir(
r#"
ConsList a : [ Cons a (ConsList a), Nil ]
isNil : ConsList a -> Bool
isNil = \list ->
when list is
Nil -> True
Cons _ _ -> False
isNil (Cons 0x2 Nil)
"#,
indoc!(
r#"
procedure Test.1 (Test.3):
let Test.13 = true;
let Test.15 = Index 0 Test.3;
let Test.14 = 1i64;
let Test.16 = lowlevel Eq Test.14 Test.15;
let Test.12 = lowlevel And Test.16 Test.13;
if Test.12 then
let Test.10 = true;
ret Test.10;
else
let Test.11 = false;
ret Test.11;
let Test.6 = 0i64;
let Test.7 = 2i64;
let Test.9 = 1i64;
let Test.8 = Nil Test.9;
let Test.5 = Cons Test.6 Test.7 Test.8;
let Test.4 = CallByName Test.1 Test.5;
ret Test.4;
"#
),
)
}
#[test]
#[ignore]
fn has_none() {
compiles_to_ir(
r#"
Maybe a : [ Just a, Nothing ]
ConsList a : [ Cons a (ConsList a), Nil ]
hasNone : ConsList (Maybe a) -> Bool
hasNone = \list ->
when list is
Nil -> False
Cons Nothing _ -> True
Cons (Just _) xs -> hasNone xs
hasNone (Cons (Just 3) Nil)
"#,
indoc!(
r#"
procedure Test.1 (Test.3):
let Test.13 = true;
let Test.15 = Index 0 Test.3;
let Test.14 = 1i64;
let Test.16 = lowlevel Eq Test.14 Test.15;
let Test.12 = lowlevel And Test.16 Test.13;
if Test.12 then
let Test.10 = true;
ret Test.10;
else
let Test.11 = false;
ret Test.11;
let Test.6 = 0i64;
let Test.7 = 2i64;
let Test.9 = 1i64;
let Test.8 = Nil Test.9;
let Test.5 = Cons Test.6 Test.7 Test.8;
let Test.4 = CallByName Test.1 Test.5;
ret Test.4;
"#
),
)
}
#[test]
fn mk_pair_of() {
compiles_to_ir(
r#"
mkPairOf = \x -> Pair x x
mkPairOf [1,2,3]
"#,
indoc!(
r#"
procedure Test.0 (Test.2):
inc Test.2;
let Test.8 = Struct {Test.2, Test.2};
ret Test.8;
let Test.5 = 1i64;
let Test.6 = 2i64;
let Test.7 = 3i64;
let Test.4 = Array [Test.5, Test.6, Test.7];
let Test.3 = CallByName Test.0 Test.4;
ret Test.3;
"#
),
)
}
#[test]
fn fst() {
compiles_to_ir(
r#"
fst = \x, y -> x
fst [1,2,3] [3,2,1]
"#,
indoc!(
r#"
procedure Test.0 (Test.2, Test.3):
inc Test.2;
ret Test.2;
let Test.10 = 1i64;
let Test.11 = 2i64;
let Test.12 = 3i64;
let Test.5 = Array [Test.10, Test.11, Test.12];
let Test.7 = 3i64;
let Test.8 = 2i64;
let Test.9 = 1i64;
let Test.6 = Array [Test.7, Test.8, Test.9];
let Test.4 = CallByName Test.0 Test.5 Test.6;
dec Test.6;
dec Test.5;
ret Test.4;
"#
),
)
}
#[test]
fn list_cannot_update_inplace() {
compiles_to_ir(
indoc!(
r#"
x : List Int
x = [1,2,3]
add : List Int -> List Int
add = \y -> List.set y 0 0
List.len (add x) + List.len x
"#
),
indoc!(
r#"
procedure Num.14 (#Attr.2, #Attr.3):
let Test.19 = lowlevel NumAdd #Attr.2 #Attr.3;
ret Test.19;
procedure Test.1 (Test.3):
let Test.13 = 0i64;
let Test.14 = 0i64;
let Test.12 = CallByName List.4 Test.3 Test.13 Test.14;
ret Test.12;
procedure List.4 (#Attr.2, #Attr.3, #Attr.4):
let Test.18 = lowlevel ListLen #Attr.2;
let Test.16 = lowlevel NumLt #Attr.3 Test.18;
if Test.16 then
let Test.17 = lowlevel ListSet #Attr.2 #Attr.3 #Attr.4;
ret Test.17;
else
ret #Attr.2;
procedure List.7 (#Attr.2):
let Test.11 = lowlevel ListLen #Attr.2;
ret Test.11;
let Test.8 = 1i64;
let Test.9 = 2i64;
let Test.10 = 3i64;
let Test.0 = Array [Test.8, Test.9, Test.10];
inc Test.0;
let Test.7 = CallByName Test.1 Test.0;
let Test.5 = CallByName List.7 Test.7;
dec Test.7;
let Test.6 = CallByName List.7 Test.0;
dec Test.0;
let Test.4 = CallByName Num.14 Test.5 Test.6;
ret Test.4;
"#
),
)
}
#[test]
fn list_get() {
compiles_to_ir(
indoc!(
r#"
main = \{} ->
List.get [1,2,3] 0
main {}
"#
),
indoc!(
r#"
procedure Test.0 (Test.2):
let Test.16 = 1i64;
let Test.17 = 2i64;
let Test.18 = 3i64;
let Test.6 = Array [Test.16, Test.17, Test.18];
let Test.7 = 0i64;
let Test.5 = CallByName List.3 Test.6 Test.7;
dec Test.6;
ret Test.5;
procedure List.3 (#Attr.2, #Attr.3):
let Test.15 = lowlevel ListLen #Attr.2;
let Test.11 = lowlevel NumLt #Attr.3 Test.15;
if Test.11 then
let Test.13 = 1i64;
let Test.14 = lowlevel ListGetUnsafe #Attr.2 #Attr.3;
let Test.12 = Ok Test.13 Test.14;
ret Test.12;
else
let Test.9 = 0i64;
let Test.10 = Struct {};
let Test.8 = Err Test.9 Test.10;
ret Test.8;
let Test.4 = Struct {};
let Test.3 = CallByName Test.0 Test.4;
ret Test.3;
"#
),
)
}
} }

View file

@ -0,0 +1,69 @@
app Quicksort provides [ quicksort ] imports []
quicksort : List Int -> List Int
quicksort = \originalList -> helper originalList
helper : List Int -> List Int
helper = \originalList ->
quicksortHelp : List (Num a), Int, Int -> List (Num a)
quicksortHelp = \list, low, high ->
if low < high then
when partition low high list is
Pair partitionIndex partitioned ->
partitioned
|> quicksortHelp low (partitionIndex - 1)
|> quicksortHelp (partitionIndex + 1) high
else
list
swap : Int, Int, List a -> List a
swap = \i, j, list ->
when Pair (List.get list i) (List.get list j) is
Pair (Ok atI) (Ok atJ) ->
list
|> List.set i atJ
|> List.set j atI
_ ->
[]
partition : Int, Int, List (Num a) -> [ Pair Int (List (Num a)) ]
partition = \low, high, initialList ->
when List.get initialList high is
Ok pivot ->
when partitionHelp (low - 1) low initialList high pivot is
Pair newI newList ->
Pair (newI + 1) (swap (newI + 1) high newList)
Err _ ->
Pair (low - 1) initialList
partitionHelp : Int, Int, List (Num a), Int, (Num a) -> [ Pair Int (List (Num a)) ]
partitionHelp = \i, j, list, high, pivot ->
if j < high then
when List.get list j is
Ok value ->
if value <= pivot then
partitionHelp (i + 1) (j + 1) (swap (i + 1) j list) high pivot
else
partitionHelp i (j + 1) list high pivot
Err _ ->
Pair i list
else
Pair i list
result = quicksortHelp originalList 0 (List.len originalList - 1)
if List.len originalList > 3 then
result
else
# Absolutely make the `originalList` Shared by using it again here
# but this branch is not evaluated, so should not affect performance
List.set originalList 0 (List.len originalList)

View file

@ -0,0 +1,47 @@
use std::time::SystemTime;
#[link(name = "roc_app", kind = "static")]
extern "C" {
#[allow(improper_ctypes)]
#[link_name = "quicksort#1"]
fn quicksort(list: &[i64]) -> Box<[i64]>;
}
const NUM_NUMS: usize = 1_000_000;
pub fn main() {
let nums = {
let mut nums = Vec::with_capacity(NUM_NUMS + 1);
// give this list refcount 1
nums.push((std::usize::MAX - 1) as i64);
for index in 1..nums.capacity() {
let num = index as i64 % 12345;
nums.push(num);
}
nums
};
println!("Running Roc shared quicksort");
let start_time = SystemTime::now();
let answer = unsafe { quicksort(&nums[1..]) };
let end_time = SystemTime::now();
let duration = end_time.duration_since(start_time).unwrap();
println!(
"Roc quicksort took {:.4} ms to compute this answer: {:?}",
duration.as_secs_f64() * 1000.0,
// truncate the answer, so stdout is not swamped
// NOTE index 0 is the refcount!
&answer[1..20]
);
// the pointer is to the first _element_ of the list,
// but the refcount precedes it. Thus calling free() on
// this pointer would segfault/cause badness. Therefore, we
// leak it for now
Box::leak(answer);
}