diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index 8f7b9d8081..c407d60ec3 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -31,7 +31,7 @@ use inkwell::OptimizationLevel; use inkwell::{AddressSpace, IntPredicate}; use roc_collections::all::{ImMap, MutSet}; use roc_module::low_level::LowLevel; -use roc_module::symbol::{Interns, Symbol}; +use roc_module::symbol::{Interns, ModuleId, Symbol}; use roc_mono::ir::{JoinPointId, Wrapped}; use roc_mono::layout::{Builtin, Layout, MemoryMode}; use target_lexicon::CallingConvention; @@ -53,6 +53,7 @@ pub enum OptLevel { #[derive(Default, Debug, Clone, PartialEq)] pub struct Scope<'a, 'ctx> { symbols: ImMap, PointerValue<'ctx>)>, + pub top_level_thunks: ImMap, FunctionValue<'ctx>)>, join_points: ImMap, &'a [PointerValue<'ctx>])>, } @@ -63,23 +64,23 @@ impl<'a, 'ctx> Scope<'a, 'ctx> { pub fn insert(&mut self, symbol: Symbol, value: (Layout<'a>, PointerValue<'ctx>)) { self.symbols.insert(symbol, value); } + pub fn insert_top_level_thunk( + &mut self, + symbol: Symbol, + layout: Layout<'a>, + function_value: FunctionValue<'ctx>, + ) { + self.top_level_thunks + .insert(symbol, (layout, function_value)); + } fn remove(&mut self, symbol: &Symbol) { self.symbols.remove(symbol); } - /* - fn get_join_point(&self, symbol: &JoinPointId) -> Option<&PhiValue<'ctx>> { - self.join_points.get(symbol) + + pub fn retain_top_level_thunks_for_module(&mut self, module_id: ModuleId) { + self.top_level_thunks + .retain(|s, _| s.module_id() == module_id); } - fn remove_join_point(&mut self, symbol: &JoinPointId) { - self.join_points.remove(symbol); - } - fn get_mut_join_point(&mut self, symbol: &JoinPointId) -> Option<&mut PhiValue<'ctx>> { - self.join_points.get_mut(symbol) - } - fn insert_join_point(&mut self, symbol: JoinPointId, value: PhiValue<'ctx>) { - self.join_points.insert(symbol, value); - } - */ } pub struct Env<'a, 'ctx, 'env> { @@ -1157,17 +1158,34 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( list_literal(env, inplace, scope, elem_layout, elems) } FunctionPointer(symbol, layout) => { - let fn_name = layout_ids - .get(*symbol, layout) - .to_symbol_string(*symbol, &env.interns); - let ptr = env - .module - .get_function(fn_name.as_str()) - .unwrap_or_else(|| panic!("Could not get pointer to unknown function {:?}", symbol)) - .as_global_value() - .as_pointer_value(); + match scope.top_level_thunks.get(symbol) { + Some((_layout, function_value)) => { + // this is a 0-argument thunk, evaluate it! + let call = env.builder.build_call( + function_value.clone(), + &[], + "evaluate_top_level_thunk", + ); - BasicValueEnum::PointerValue(ptr) + call.try_as_basic_value().left().unwrap() + } + None => { + // this is a function pointer, store it + let fn_name = layout_ids + .get(*symbol, layout) + .to_symbol_string(*symbol, &env.interns); + let ptr = env + .module + .get_function(fn_name.as_str()) + .unwrap_or_else(|| { + panic!("Could not get pointer to unknown function {:?}", symbol) + }) + .as_global_value() + .as_pointer_value(); + + BasicValueEnum::PointerValue(ptr) + } + } } RuntimeErrorFunction(_) => todo!(), } @@ -1847,6 +1865,7 @@ pub fn build_proc_header<'a, 'ctx, 'env>( pub fn build_proc<'a, 'ctx, 'env>( env: &'a Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, + mut scope: Scope<'a, 'ctx>, proc: roc_mono::ir::Proc<'a>, fn_val: FunctionValue<'ctx>, ) { @@ -1859,8 +1878,6 @@ pub fn build_proc<'a, 'ctx, 'env>( builder.position_at_end(entry); - let mut scope = Scope::default(); - // Add args to scope for (arg_val, (layout, arg_symbol)) in fn_val.get_param_iter().zip(args) { set_name(arg_val, arg_symbol.ident_string(&env.interns)); diff --git a/compiler/gen/tests/gen_primitives.rs b/compiler/gen/tests/gen_primitives.rs index 6a53c93974..23d2d6de10 100644 --- a/compiler/gen/tests/gen_primitives.rs +++ b/compiler/gen/tests/gen_primitives.rs @@ -529,7 +529,6 @@ mod gen_primitives { } #[test] - #[ignore] fn top_level_constant() { assert_evals_to!( indoc!( diff --git a/compiler/gen/tests/helpers/eval.rs b/compiler/gen/tests/helpers/eval.rs index 5377918cbb..ced5d6ab94 100644 --- a/compiler/gen/tests/helpers/eval.rs +++ b/compiler/gen/tests/helpers/eval.rs @@ -25,7 +25,7 @@ pub fn helper<'a>( inkwell::execution_engine::ExecutionEngine<'a>, ) { use inkwell::OptimizationLevel; - use roc_gen::llvm::build::{build_proc, build_proc_header}; + use roc_gen::llvm::build::{build_proc, build_proc_header, Scope}; use std::path::{Path, PathBuf}; let stdlib_mode = stdlib.mode; @@ -141,15 +141,29 @@ pub fn helper<'a>( // Add all the Proc headers to the module. // We have to do this in a separate pass first, // because their bodies may reference each other. + let mut scope = Scope::default(); for ((symbol, layout), proc) in procedures.drain() { let fn_val = build_proc_header(&env, &mut layout_ids, symbol, &layout, &proc); + if proc.args.is_empty() { + // this is a 0-argument thunk, i.e. a top-level constant definition + // it must be in-scope everywhere in the module! + scope.insert_top_level_thunk(symbol, layout, fn_val); + } + headers.push((proc, fn_val)); } // Build each proc using its header info. for (proc, fn_val) in headers { - build_proc(&env, &mut layout_ids, proc, fn_val); + let mut current_scope = scope.clone(); + + // only have top-level thunks for this proc's module in scope + // this retain is not needed for correctness, but will cause less confusion when debugging + let home = proc.name.module_id(); + current_scope.retain_top_level_thunks_for_module(home); + + build_proc(&env, &mut layout_ids, scope.clone(), proc, fn_val); if fn_val.verify(true) { function_pass.run_on(&fn_val); diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index 4676341203..4346c69c7b 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -203,25 +203,23 @@ impl<'a> Procs<'a> { for (key, in_prog_proc) in self.specialized.into_iter() { match in_prog_proc { InProgress => unreachable!("The procedure {:?} should have be done by now", key), - Done(proc) => { + Done(mut proc) => { + use self::SelfRecursive::*; + if let SelfRecursive(id) = proc.is_self_recursive { + proc.body = crate::tail_recursion::make_tail_recursive( + arena, + id, + proc.name, + proc.body.clone(), + proc.args, + ); + } + result.insert(key, proc); } } } - for (_, proc) in result.iter_mut() { - use self::SelfRecursive::*; - if let SelfRecursive(id) = proc.is_self_recursive { - proc.body = crate::tail_recursion::make_tail_recursive( - arena, - id, - proc.name, - proc.body.clone(), - proc.args, - ); - } - } - result }