diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index 0573b2a57e..dfb65a5f1b 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -237,6 +237,7 @@ pub fn build_expr<'a, 'ctx, 'env>( default_branch: (default_stores, default_expr), ret_layout, cond_layout, + cond_symbol: _, } => { let ret_type = basic_type_from_layout(env.arena, env.context, &ret_layout, env.ptr_bytes); @@ -346,6 +347,7 @@ pub fn build_expr<'a, 'ctx, 'env>( .left() .unwrap_or_else(|| panic!("LLVM error: Invalid call by pointer.")) } + LoadWithoutIncrement(_symbol) => todo!("implement load without increment"), Load(symbol) => load_symbol(env, scope, symbol), Str(str_literal) => { if str_literal.is_empty() { @@ -718,8 +720,10 @@ pub fn build_expr<'a, 'ctx, 'env>( } RunLowLevel(op, args) => run_low_level(env, layout_ids, scope, parent, *op, args), - IncBefore(_, expr) => build_expr(env, layout_ids, scope, parent, expr), DecAfter(_, expr) => build_expr(env, layout_ids, scope, parent, expr), + + Reuse(_, expr) => build_expr(env, layout_ids, scope, parent, expr), + Reset(_, expr) => build_expr(env, layout_ids, scope, parent, expr), } } diff --git a/compiler/mono/src/decision_tree.rs b/compiler/mono/src/decision_tree.rs index ad3e3fef73..3f820d86e1 100644 --- a/compiler/mono/src/decision_tree.rs +++ b/compiler/mono/src/decision_tree.rs @@ -1194,6 +1194,7 @@ fn decide_to_branching<'a>( Expr::Switch { cond: env.arena.alloc(cond), cond_layout, + cond_symbol, branches: branches.into_bump_slice(), default_branch, ret_layout, diff --git a/compiler/mono/src/expr.rs b/compiler/mono/src/expr.rs index b831c621f7..d637a376d9 100644 --- a/compiler/mono/src/expr.rs +++ b/compiler/mono/src/expr.rs @@ -288,6 +288,10 @@ pub enum Expr<'a> { Load(Symbol), Store(&'a [(Symbol, Layout<'a>, Expr<'a>)], &'a Expr<'a>), + /// RC instructions + LoadWithoutIncrement(Symbol), + DecAfter(Symbol, &'a Expr<'a>), + // Functions FunctionPointer(Symbol, Layout<'a>), RuntimeErrorFunction(&'a str), @@ -323,6 +327,7 @@ pub enum Expr<'a> { /// This *must* be an integer, because Switch potentially compiles to a jump table. cond: &'a Expr<'a>, cond_layout: Layout<'a>, + cond_symbol: Symbol, /// The u64 in the tuple will be compared directly to the condition Expr. /// If they are equal, this branch will be taken. branches: &'a [(u64, Stores<'a>, Expr<'a>)], @@ -352,13 +357,100 @@ pub enum Expr<'a> { }, EmptyArray, - /// RC instructions - IncBefore(Symbol, &'a Expr<'a>), - DecAfter(Symbol, &'a Expr<'a>), + /// Reset/Reuse + + /// Re-use Symbol in Expr + Reuse(Symbol, &'a Expr<'a>), + Reset(Symbol, &'a Expr<'a>), RuntimeError(&'a str), } +fn function_r<'a>(env: &mut Env<'a, '_>, body: &'a Expr<'a>) -> &'a Expr<'a> { + use Expr::*; + + match body { + Switch { + cond_symbol, + branches, + cond, + cond_layout, + default_branch, + ret_layout, + } => { + let stack_size = cond_layout.stack_size(env.pointer_size); + let mut new_branches = Vec::with_capacity_in(branches.len(), env.arena); + + for (tag, stores, branch) in branches.iter() { + let new_branch = function_d(env, *cond_symbol, stack_size as _, branch); + + new_branches.push((*tag, *stores, new_branch)); + } + + let new_default_branch = ( + default_branch.0, + &*env.arena.alloc(function_d( + env, + *cond_symbol, + stack_size as _, + default_branch.1, + )), + ); + + env.arena.alloc(Switch { + cond_symbol: *cond_symbol, + branches: new_branches.into_bump_slice(), + default_branch: new_default_branch, + ret_layout: ret_layout.clone(), + cond: *cond, + cond_layout: cond_layout.clone(), + }) + } + Store(stores, body) => { + let new_body = function_r(env, body); + + env.arena.alloc(Store(stores, new_body)) + } + + _ => body, + } +} + +fn function_d<'a>( + env: &mut Env<'a, '_>, + z: Symbol, + stack_size: usize, + body: &'a Expr<'a>, +) -> Expr<'a> { + if let Some(reused) = function_s(env, z, stack_size, body) { + Expr::Reset(z, env.arena.alloc(reused)) + } else { + body.clone() + } + /* + match body { + Expr::Tag { .. } => Some(env.arena.alloc(Expr::Reuse(w, body))), + _ => None, + } + */ +} + +fn function_s<'a>( + env: &mut Env<'a, '_>, + w: Symbol, + stack_size: usize, + body: &'a Expr<'a>, +) -> Option<&'a Expr<'a>> { + match body { + Expr::Tag { tag_layout, .. } + if tag_layout.stack_size(env.pointer_size) as usize <= stack_size => + { + Some(env.arena.alloc(Expr::Reuse(w, body))) + } + _ => None, + } +} + #[derive(Clone, Debug)] pub enum MonoProblem { PatternProblem(crate::pattern::Error), @@ -373,7 +465,8 @@ impl<'a> Expr<'a> { ) -> Self { let mut layout_cache = LayoutCache::default(); - from_can(env, can_expr, procs, &mut layout_cache) + let result = from_can(env, can_expr, procs, &mut layout_cache); + function_r(env, env.arena.alloc(result)).clone() } } @@ -570,6 +663,15 @@ fn pattern_to_when<'a>( } } +fn decrement_refcount<'a>(env: &mut Env<'a, '_>, symbol: Symbol, expr: Expr<'a>) -> Expr<'a> { + // TODO are there any builtins that should be refcounted? + if symbol.is_builtin() { + expr + } else { + Expr::DecAfter(symbol, env.arena.alloc(expr)) + } +} + #[allow(clippy::cognitive_complexity)] fn from_can<'a>( env: &mut Env<'a, '_>, @@ -604,7 +706,8 @@ fn from_can<'a>( layout_cache, ) } else { - Expr::IncBefore(symbol, env.arena.alloc(Expr::Load(symbol))) + // NOTE Load will always increment the refcount + Expr::Load(symbol) } } LetRec(defs, ret_expr, _, _) => from_can_defs(env, defs, *ret_expr, layout_cache, procs), @@ -614,7 +717,7 @@ fn from_can<'a>( // TODO is order important here? for symbol in symbols { - result = Expr::DecAfter(symbol, env.arena.alloc(result)); + result = decrement_refcount(env, symbol, result); } result @@ -699,17 +802,35 @@ fn from_can<'a>( region, loc_cond, branches, - } => from_can_when( - env, - cond_var, - expr_var, - region, - *loc_cond, - branches, - layout_cache, - procs, - ), + } => { + let cond_symbol = if let roc_can::expr::Expr::Var(symbol) = loc_cond.value { + symbol + } else { + env.unique_symbol() + }; + let mono_when = from_can_when( + env, + cond_var, + expr_var, + region, + cond_symbol, + branches, + layout_cache, + procs, + ); + + let mono_cond = from_can(env, loc_cond.value, procs, layout_cache); + + let cond_layout = layout_cache + .from_var(env.arena, cond_var, env.subs, env.pointer_size) + .expect("invalid cond_layout"); + + Expr::Store( + env.arena.alloc([(cond_symbol, cond_layout, mono_cond)]), + env.arena.alloc(mono_when), + ) + } If { cond_var, branch_var, @@ -1203,7 +1324,7 @@ fn from_can_when<'a>( cond_var: Variable, expr_var: Variable, region: Region, - loc_cond: Located, + cond_symbol: Symbol, mut branches: std::vec::Vec, layout_cache: &mut LayoutCache<'a>, procs: &mut Procs<'a>, @@ -1260,9 +1381,6 @@ fn from_can_when<'a>( let cond_layout = layout_cache .from_var(env.arena, cond_var, env.subs, env.pointer_size) .unwrap_or_else(|err| panic!("TODO turn this into a RuntimeError {:?}", err)); - let cond_symbol = env.unique_symbol(); - let cond = from_can(env, loc_cond.value, procs, layout_cache); - stored.push((cond_symbol, cond_layout.clone(), cond)); // NOTE this will still store shadowed names. // that's fine: the branch throws a runtime error anyway @@ -1273,7 +1391,7 @@ fn from_can_when<'a>( }; for symbol in bound_symbols { - ret = Expr::DecAfter(symbol, env.arena.alloc(ret)); + ret = decrement_refcount(env, symbol, ret); } Expr::Store(stored.into_bump_slice(), arena.alloc(ret)) @@ -1282,9 +1400,6 @@ fn from_can_when<'a>( .from_var(env.arena, cond_var, env.subs, env.pointer_size) .unwrap_or_else(|err| panic!("TODO turn this into a RuntimeError {:?}", err)); - let cond = from_can(env, loc_cond.value, procs, layout_cache); - let cond_symbol = env.unique_symbol(); - let mut loc_branches = std::vec::Vec::new(); let mut opt_branches = std::vec::Vec::new(); @@ -1406,17 +1521,13 @@ fn from_can_when<'a>( .from_var(env.arena, expr_var, env.subs, env.pointer_size) .unwrap_or_else(|err| panic!("TODO turn this into a RuntimeError {:?}", err)); - let branching = crate::decision_tree::optimize_when( + crate::decision_tree::optimize_when( env, cond_symbol, cond_layout.clone(), ret_layout, opt_branches, - ); - - let stores = env.arena.alloc([(cond_symbol, cond_layout, cond)]); - - Expr::Store(stores, env.arena.alloc(branching)) + ) } } @@ -1640,7 +1751,7 @@ fn specialize<'a>( // TODO does order matter here? for &symbol in pattern_symbols.iter() { - specialized_body = Expr::DecAfter(symbol, env.arena.alloc(specialized_body)); + specialized_body = decrement_refcount(env, symbol, specialized_body); } // reset subs, so we don't get type errors when specializing for a different signature diff --git a/compiler/mono/tests/test_mono.rs b/compiler/mono/tests/test_mono.rs index 19a65ac364..0ae4d2ed9e 100644 --- a/compiler/mono/tests/test_mono.rs +++ b/compiler/mono/tests/test_mono.rs @@ -11,6 +11,7 @@ mod helpers; mod test_mono { use crate::helpers::{can_expr, infer_expr, test_home, CanExprOut}; use bumpalo::Bump; + use roc_module::ident::TagName; use roc_module::symbol::{Interns, Symbol}; use roc_mono::expr::Expr::{self, *}; use roc_mono::expr::Procs; @@ -166,30 +167,33 @@ mod test_mono { let home = test_home(); let gen_symbol_0 = Interns::from_index(home, 0); - Struct(&[ - ( - CallByName { - name: gen_symbol_0, - layout: Layout::FunctionPointer( - &[Layout::Builtin(Builtin::Int64)], - &Layout::Builtin(Builtin::Int64), - ), - args: &[(Int(4), Layout::Builtin(Int64))], - }, - Layout::Builtin(Int64), - ), - ( - CallByName { - name: gen_symbol_0, - layout: Layout::FunctionPointer( - &[Layout::Builtin(Builtin::Float64)], - &Layout::Builtin(Builtin::Float64), - ), - args: &[(Float(3.14), Layout::Builtin(Float64))], - }, - Layout::Builtin(Float64), - ), - ]) + DecAfter( + gen_symbol_0, + &Struct(&[ + ( + CallByName { + name: gen_symbol_0, + layout: Layout::FunctionPointer( + &[Layout::Builtin(Builtin::Int64)], + &Layout::Builtin(Builtin::Int64), + ), + args: &[(Int(4), Layout::Builtin(Int64))], + }, + Layout::Builtin(Int64), + ), + ( + CallByName { + name: gen_symbol_0, + layout: Layout::FunctionPointer( + &[Layout::Builtin(Builtin::Float64)], + &Layout::Builtin(Builtin::Float64), + ), + args: &[(Float(3.14), Layout::Builtin(Float64))], + }, + Layout::Builtin(Float64), + ), + ]), + ) }, ) } @@ -299,27 +303,30 @@ mod test_mono { let gen_symbol_0 = Interns::from_index(home, 1); let symbol_x = Interns::from_index(home, 0); - Store( - &[( - symbol_x, - Builtin(Str), - Store( - &[( - gen_symbol_0, - Layout::Builtin(layout::Builtin::Int1), - Expr::Bool(true), - )], - &Cond { - cond_symbol: gen_symbol_0, - branch_symbol: gen_symbol_0, - cond_layout: Builtin(Int1), - pass: (&[] as &[_], &Expr::Str("bar")), - fail: (&[] as &[_], &Expr::Str("foo")), - ret_layout: Builtin(Str), - }, - ), - )], - &Load(symbol_x), + DecAfter( + symbol_x, + &Store( + &[( + symbol_x, + Builtin(Str), + Store( + &[( + gen_symbol_0, + Layout::Builtin(layout::Builtin::Int1), + Expr::Bool(true), + )], + &Cond { + cond_symbol: gen_symbol_0, + branch_symbol: gen_symbol_0, + cond_layout: Builtin(Int1), + pass: (&[] as &[_], &Expr::Str("bar")), + fail: (&[] as &[_], &Expr::Str("foo")), + ret_layout: Builtin(Str), + }, + ), + )], + &Load(symbol_x), + ), ) }, ) @@ -359,27 +366,30 @@ mod test_mono { let gen_symbol_0 = Interns::from_index(home, 0); let struct_layout = Layout::Struct(&[I64_LAYOUT, F64_LAYOUT]); - CallByName { - name: gen_symbol_0, - layout: Layout::FunctionPointer( - &[struct_layout.clone()], - &struct_layout.clone(), - ), - args: &[( - Struct(&[ - ( - CallByName { - name: gen_symbol_0, - layout: Layout::FunctionPointer(&[I64_LAYOUT], &I64_LAYOUT), - args: &[(Int(4), I64_LAYOUT)], - }, - I64_LAYOUT, - ), - (Float(0.1), F64_LAYOUT), - ]), - struct_layout, - )], - } + DecAfter( + gen_symbol_0, + &CallByName { + name: gen_symbol_0, + layout: Layout::FunctionPointer( + &[struct_layout.clone()], + &struct_layout.clone(), + ), + args: &[( + Struct(&[ + ( + CallByName { + name: gen_symbol_0, + layout: Layout::FunctionPointer(&[I64_LAYOUT], &I64_LAYOUT), + args: &[(Int(4), I64_LAYOUT)], + }, + I64_LAYOUT, + ), + (Float(0.1), F64_LAYOUT), + ]), + struct_layout, + )], + }, + ) }, ) } @@ -493,7 +503,9 @@ mod test_mono { let load = Load(var_x); - Store(arena.alloc(stores), arena.alloc(load)) + let store = Store(arena.alloc(stores), arena.alloc(load)); + + DecAfter(var_x, arena.alloc(store)) }, ); } @@ -517,7 +529,9 @@ mod test_mono { let load = Load(var_x); - Store(arena.alloc(stores), arena.alloc(load)) + let store = Store(arena.alloc(stores), arena.alloc(load)); + + DecAfter(var_x, arena.alloc(store)) }, ); } @@ -543,7 +557,9 @@ mod test_mono { let load = Load(var_x); - Store(arena.alloc(stores), arena.alloc(load)) + let store = Store(arena.alloc(stores), arena.alloc(load)); + + DecAfter(var_x, arena.alloc(store)) }, ); } @@ -589,33 +605,143 @@ mod test_mono { }); } - // #[test] - // fn when_on_result() { - // compiles_to( - // r#" - // when 1 is - // 1 -> 12 - // _ -> 34 - // "#, - // { - // use self::Builtin::*; - // use Layout::Builtin; - // let home = test_home(); + // #[test] + // fn when_on_result() { + // compiles_to( + // r#" + // when 1 is + // 1 -> 12 + // _ -> 34 + // "#, + // { + // use self::Builtin::*; + // use Layout::Builtin; + // let home = test_home(); // - // let gen_symbol_3 = Interns::from_index(home, 3); - // let gen_symbol_4 = Interns::from_index(home, 4); + // let gen_symbol_3 = Interns::from_index(home, 3); + // let gen_symbol_4 = Interns::from_index(home, 4); // - // CallByName( - // gen_symbol_3, - // &[( - // Struct(&[( - // CallByName(gen_symbol_4, &[(Int(4), Builtin(Int64))]), - // Builtin(Int64), - // )]), - // Layout::Struct(&[("x".into(), Builtin(Int64))]), - // )], - // ) - // }, - // ) - // } + // CallByName( + // gen_symbol_3, + // &[( + // Struct(&[( + // CallByName(gen_symbol_4, &[(Int(4), Builtin(Int64))]), + // Builtin(Int64), + // )]), + // Layout::Struct(&[("x".into(), Builtin(Int64))]), + // )], + // ) + // }, + // ) + // } + + #[test] + fn insert_reset_reuse() { + compiles_to( + r#" + when Foo 1 is + Foo _ -> Foo 1 + Bar -> Foo 2 + Baz -> Foo 2 + a -> a + "#, + { + use self::Builtin::*; + use Layout::{Builtin, Union}; + + let home = test_home(); + let gen_symbol_1 = Interns::from_index(home, 1); + + let union_layout = Union(&[ + &[Builtin(Int64)], + &[Builtin(Int64)], + &[Builtin(Int64), Builtin(Int64)], + ]); + + Store( + &[( + gen_symbol_1, + union_layout.clone(), + Tag { + tag_layout: union_layout.clone(), + tag_name: TagName::Global("Foo".into()), + tag_id: 2, + union_size: 3, + arguments: &[(Int(2), Builtin(Int64)), (Int(1), Builtin(Int64))], + }, + )], + &Store( + &[], + &Switch { + cond: &Load(gen_symbol_1), + cond_symbol: gen_symbol_1, + branches: &[ + ( + 2, + &[], + Reset( + gen_symbol_1, + &Reuse( + gen_symbol_1, + &Tag { + tag_layout: union_layout.clone(), + tag_name: TagName::Global("Foo".into()), + tag_id: 2, + union_size: 3, + arguments: &[ + (Int(2), Builtin(Int64)), + (Int(1), Builtin(Int64)), + ], + }, + ), + ), + ), + ( + 0, + &[], + Reset( + gen_symbol_1, + &Reuse( + gen_symbol_1, + &Tag { + tag_layout: union_layout.clone(), + tag_name: TagName::Global("Foo".into()), + tag_id: 2, + union_size: 3, + arguments: &[ + (Int(2), Builtin(Int64)), + (Int(2), Builtin(Int64)), + ], + }, + ), + ), + ), + ], + default_branch: ( + &[], + &Reset( + gen_symbol_1, + &Reuse( + gen_symbol_1, + &Tag { + tag_layout: union_layout.clone(), + tag_name: TagName::Global("Foo".into()), + tag_id: 2, + union_size: 3, + arguments: &[ + (Int(2), Builtin(Int64)), + (Int(2), Builtin(Int64)), + ], + }, + ), + ), + ), + ret_layout: union_layout.clone(), + cond_layout: union_layout, + }, + ), + ) + }, + ) + } }