From 3739f41cac210bfaaa4c435463889a8bc6dce10b Mon Sep 17 00:00:00 2001 From: Folkert Date: Sat, 15 May 2021 21:25:38 +0200 Subject: [PATCH] explicitly store and pass layout of a function passed to lowlevel --- compiler/gen/src/llvm/build.rs | 131 +++++++++++++++++++--------- compiler/gen/src/llvm/build_list.rs | 8 +- compiler/mono/src/ir.rs | 51 ++--------- 3 files changed, 103 insertions(+), 87 deletions(-) diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index 7fc8e29bcf..3e51f665b8 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -131,7 +131,6 @@ impl<'ctx> Iterator for FunctionIterator<'ctx> { #[derive(Default, Debug, Clone, PartialEq)] pub struct Scope<'a, 'ctx> { symbols: ImMap, BasicValueEnum<'ctx>)>, - pub function_pointers: MutMap, FunctionValue<'ctx>)>, pub top_level_thunks: ImMap, FunctionValue<'ctx>)>, join_points: ImMap, &'a [PointerValue<'ctx>])>, } @@ -885,7 +884,16 @@ pub fn build_exp_call<'a, 'ctx, 'env>( CallType::LowLevel { op, opt_closure_layout, - } => run_low_level(env, layout_ids, scope, parent, layout, *op, arguments), + } => run_low_level( + env, + layout_ids, + scope, + parent, + layout, + *op, + *opt_closure_layout, + arguments, + ), CallType::Foreign { foreign_symbol, @@ -1702,10 +1710,6 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( ) }); - scope - .function_pointers - .insert(left_hand_side, (*layout, function_value)); - let ptr = function_value.as_global_value().as_pointer_value(); BasicValueEnum::PointerValue(ptr) @@ -3609,10 +3613,65 @@ fn run_low_level<'a, 'ctx, 'env>( parent: FunctionValue<'ctx>, layout: &Layout<'a>, op: LowLevel, + opt_closure_layout: Option>, args: &[Symbol], ) -> BasicValueEnum<'ctx> { use LowLevel::*; + // macros because functions cause lifetime issues related to the `env` or `layout_ids` + macro_rules! passed_function_at_index { + ($function_layout:expr, $index:expr) => {{ + let function_symbol = args[$index]; + + let fn_name = layout_ids + .get(function_symbol, &$function_layout) + .to_symbol_string(function_symbol, &env.interns); + + env.module + .get_function(fn_name.as_str()) + .unwrap_or_else(|| { + panic!( + "Could not get pointer to unknown function {:?} {:?}", + fn_name, $function_layout + ) + }) + }}; + } + + macro_rules! list_walk { + ($variant:expr) => {{ + let (list, list_layout) = load_symbol_and_layout(scope, &args[0]); + + let (default, default_layout) = load_symbol_and_layout(scope, &args[1]); + + let function_layout = opt_closure_layout.unwrap(); + let function = passed_function_at_index!(function_layout, 2); + + let (closure, closure_layout) = load_symbol_and_layout(scope, &args[3]); + + match list_layout { + Layout::Builtin(Builtin::EmptyList) => default, + Layout::Builtin(Builtin::List(_, element_layout)) => { + crate::llvm::build_list::list_walk_generic( + env, + layout_ids, + parent, + list, + element_layout, + function, + function_layout, + closure, + *closure_layout, + default, + default_layout, + $variant, + ) + } + _ => unreachable!("invalid list layout"), + } + }}; + } + match op { StrConcat => { // Str.concat : Str, Str -> Str @@ -3759,7 +3818,8 @@ fn run_low_level<'a, 'ctx, 'env>( let (list, list_layout) = load_symbol_and_layout(scope, &args[0]); - let (function_layout, function) = scope.function_pointers[&args[1]]; + let function_layout = opt_closure_layout.unwrap(); + let function = passed_function_at_index!(function_layout, 1); let (closure, closure_layout) = load_symbol_and_layout(scope, &args[2]); @@ -3784,7 +3844,8 @@ fn run_low_level<'a, 'ctx, 'env>( let (list1, list1_layout) = load_symbol_and_layout(scope, &args[0]); let (list2, list2_layout) = load_symbol_and_layout(scope, &args[1]); - let (function_layout, function) = scope.function_pointers[&args[2]]; + let function_layout = opt_closure_layout.unwrap(); + let function = passed_function_at_index!(function_layout, 2); let (closure, closure_layout) = load_symbol_and_layout(scope, &args[3]); match (list1_layout, list2_layout) { @@ -3815,7 +3876,8 @@ fn run_low_level<'a, 'ctx, 'env>( let (list2, list2_layout) = load_symbol_and_layout(scope, &args[1]); let (list3, list3_layout) = load_symbol_and_layout(scope, &args[2]); - let (function_layout, function) = scope.function_pointers[&args[3]]; + let function_layout = opt_closure_layout.unwrap(); + let function = passed_function_at_index!(function_layout, 3); let (closure, closure_layout) = load_symbol_and_layout(scope, &args[4]); match (list1_layout, list2_layout, list3_layout) { @@ -3849,7 +3911,8 @@ fn run_low_level<'a, 'ctx, 'env>( let (list, list_layout) = load_symbol_and_layout(scope, &args[0]); - let (function_layout, function) = scope.function_pointers[&args[1]]; + let function_layout = opt_closure_layout.unwrap(); + let function = passed_function_at_index!(function_layout, 1); let (closure, closure_layout) = load_symbol_and_layout(scope, &args[2]); @@ -3874,7 +3937,8 @@ fn run_low_level<'a, 'ctx, 'env>( let (list, list_layout) = load_symbol_and_layout(scope, &args[0]); - let (function_layout, function) = scope.function_pointers[&args[1]]; + let function_layout = opt_closure_layout.unwrap(); + let function = passed_function_at_index!(function_layout, 1); let (closure, closure_layout) = load_symbol_and_layout(scope, &args[2]); @@ -3899,7 +3963,8 @@ fn run_low_level<'a, 'ctx, 'env>( let (list, list_layout) = load_symbol_and_layout(scope, &args[0]); - let (function_layout, function) = scope.function_pointers[&args[1]]; + let function_layout = opt_closure_layout.unwrap(); + let function = passed_function_at_index!(function_layout, 1); let (closure, closure_layout) = load_symbol_and_layout(scope, &args[2]); @@ -3931,7 +3996,8 @@ fn run_low_level<'a, 'ctx, 'env>( let (list, list_layout) = load_symbol_and_layout(scope, &args[0]); - let (function_layout, function) = scope.function_pointers[&args[1]]; + let function_layout = opt_closure_layout.unwrap(); + let function = passed_function_at_index!(function_layout, 1); let (closure, closure_layout) = load_symbol_and_layout(scope, &args[2]); @@ -3981,30 +4047,15 @@ fn run_low_level<'a, 'ctx, 'env>( list_range(env, *builtin, low.into_int_value(), high.into_int_value()) } - ListWalk => list_walk_help( - env, - layout_ids, - scope, - parent, - args, - crate::llvm::build_list::ListWalk::Walk, - ), - ListWalkUntil => list_walk_help( - env, - layout_ids, - scope, - parent, - args, - crate::llvm::build_list::ListWalk::WalkUntil, - ), - ListWalkBackwards => list_walk_help( - env, - layout_ids, - scope, - parent, - args, - crate::llvm::build_list::ListWalk::WalkBackwards, - ), + ListWalk => { + list_walk!(crate::llvm::build_list::ListWalk::Walk) + } + ListWalkUntil => { + list_walk!(crate::llvm::build_list::ListWalk::WalkUntil) + } + ListWalkBackwards => { + list_walk!(crate::llvm::build_list::ListWalk::WalkBackwards) + } ListAppend => { // List.append : List elem, elem -> List elem debug_assert_eq!(args.len(), 2); @@ -4043,7 +4094,8 @@ fn run_low_level<'a, 'ctx, 'env>( let (list, list_layout) = load_symbol_and_layout(scope, &args[0]); - let (_, function) = scope.function_pointers[&args[1]]; + let function_layout = opt_closure_layout.unwrap(); + let function = passed_function_at_index!(function_layout, 1); let (closure, closure_layout) = load_symbol_and_layout(scope, &args[2]); @@ -4499,7 +4551,8 @@ fn run_low_level<'a, 'ctx, 'env>( let (dict, dict_layout) = load_symbol_and_layout(scope, &args[0]); let (default, default_layout) = load_symbol_and_layout(scope, &args[1]); - let (function_layout, function) = scope.function_pointers[&args[2]]; + let function_layout = opt_closure_layout.unwrap(); + let function = passed_function_at_index!(function_layout, 2); let (closure, closure_layout) = load_symbol_and_layout(scope, &args[3]); match dict_layout { diff --git a/compiler/gen/src/llvm/build_list.rs b/compiler/gen/src/llvm/build_list.rs index f5108e0adf..8bad938cc9 100644 --- a/compiler/gen/src/llvm/build_list.rs +++ b/compiler/gen/src/llvm/build_list.rs @@ -446,11 +446,13 @@ pub enum ListWalk { } pub fn list_walk_help<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, + env: &'ctx Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, scope: &crate::llvm::build::Scope<'a, 'ctx>, parent: FunctionValue<'ctx>, args: &[roc_module::symbol::Symbol], + function: FunctionValue<'a>, + function_layout: Layout<'a>, variant: ListWalk, ) -> BasicValueEnum<'ctx> { use crate::llvm::build::load_symbol_and_layout; @@ -461,8 +463,6 @@ pub fn list_walk_help<'a, 'ctx, 'env>( let (default, default_layout) = load_symbol_and_layout(scope, &args[1]); - let (function_layout, function) = scope.function_pointers[&args[2]]; - let (closure, closure_layout) = load_symbol_and_layout(scope, &args[3]); match list_layout { @@ -485,7 +485,7 @@ pub fn list_walk_help<'a, 'ctx, 'env>( } } -fn list_walk_generic<'a, 'ctx, 'env>( +pub fn list_walk_generic<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, _parent: FunctionValue<'ctx>, diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index cc1ce1e35e..0c6b405ada 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -7691,26 +7691,16 @@ where Layout::Struct(_) => { let function_symbol = lambda_set.set[0].0; - let bound = env.unique_symbol(); - // build the call - let stmt = Stmt::Let( + Stmt::Let( assigned, Expr::Call(to_lowlevel_call( - bound, + function_symbol, closure_data_symbol, function_layout, )), return_layout, env.arena.alloc(hole), - ); - - // fix the layout; needs top-level signature - Stmt::Let( - bound, - Expr::FunctionPointer(function_symbol, function_layout), - function_layout, - env.arena.alloc(stmt), ) } Layout::Builtin(Builtin::Int1) => { @@ -7775,13 +7765,11 @@ where let hole = Stmt::Jump(join_point_id, env.arena.alloc([assigned])); - let bound = env.unique_symbol(); - // build the call let stmt = Stmt::Let( assigned, Expr::Call(to_lowlevel_call( - bound, + *function_symbol, closure_data_symbol, function_layout, )), @@ -7789,13 +7777,6 @@ where env.arena.alloc(hole), ); - let stmt = Stmt::Let( - bound, - Expr::FunctionPointer(*function_symbol, function_layout), - function_layout, - env.arena.alloc(stmt), - ); - branches.push((i as u64, BranchInfo::None, stmt)); } @@ -7844,25 +7825,16 @@ where let hole = Stmt::Jump(join_point_id, env.arena.alloc([assigned])); - let bound = env.unique_symbol(); - // build the call - let stmt = Stmt::Let( + Stmt::Let( assigned, Expr::Call(to_lowlevel_call( - bound, + function_symbol, closure_data_symbol, function_layout, )), return_layout, env.arena.alloc(hole), - ); - - Stmt::Let( - assigned, - Expr::FunctionPointer(function_symbol, function_layout), - function_layout, - env.arena.alloc(stmt), ) } @@ -8328,24 +8300,15 @@ where let hole = Stmt::Jump(join_point_id, env.arena.alloc([result_symbol])); - let bound = env.unique_symbol(); - // build the call - let stmt = Stmt::Let( + Stmt::Let( result_symbol, Expr::Call(to_lowlevel_call( - bound, + function_symbol, closure_data_symbol, function_layout, )), return_layout, env.arena.alloc(hole), - ); - - Stmt::Let( - bound, - Expr::FunctionPointer(function_symbol, function_layout), - function_layout, - env.arena.alloc(stmt), ) }