diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index d4c56854a2..dfa0c3c281 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -362,289 +362,24 @@ pub fn construct_optimization_passes<'a>( (mpm, fpm) } -/// For communication with C (tests and platforms) we need to abide by the C calling convention -/// -/// While small values are just returned like with the fast CC, larger structures need to -/// be written into a pointer (into the callers stack) -enum PassVia { - Register, - Memory, -} - -impl PassVia { - fn from_layout(ptr_bytes: u32, layout: &Layout<'_>) -> Self { - let stack_size = layout.stack_size(ptr_bytes); - let eightbyte = 8; - - if stack_size > 2 * eightbyte { - PassVia::Memory - } else { - PassVia::Register - } - } -} - -/// entry point to roc code; uses the fastcc calling convention -pub fn build_roc_main<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - layout_ids: &mut LayoutIds<'a>, - layout: &Layout<'a>, - main_body: &roc_mono::ir::Stmt<'a>, -) -> &'a FunctionValue<'ctx> { - use inkwell::types::BasicType; - - let context = env.context; - let builder = env.builder; - let arena = env.arena; - let ptr_bytes = env.ptr_bytes; - - let return_type = basic_type_from_layout(&arena, context, &layout, ptr_bytes); - let roc_main_fn_name = "$Test.roc_main"; - - // make the roc main function - let roc_main_fn_type = return_type.fn_type(&[], false); - - // Add main to the module. - let roc_main_fn = env - .module - .add_function(roc_main_fn_name, roc_main_fn_type, None); - - // internal function, use fast calling convention - roc_main_fn.set_call_conventions(FAST_CALL_CONV); - - // Add main's body - let basic_block = context.append_basic_block(roc_main_fn, "entry"); - - builder.position_at_end(basic_block); - - // builds the function body (return statement included) - build_exp_stmt( - env, - layout_ids, - &mut Scope::default(), - roc_main_fn, - main_body, - ); - - env.arena.alloc(roc_main_fn) -} - pub fn promote_to_main_function<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, symbol: Symbol, layout: &Layout<'a>, -) -> (&'static str, &'a FunctionValue<'ctx>) { +) -> (&'static str, FunctionValue<'ctx>) { let fn_name = layout_ids .get(symbol, layout) .to_symbol_string(symbol, &env.interns); - let wrapped = env.module.get_function(&fn_name).unwrap(); - - make_main_function_help(env, layout, wrapped) -} - -pub fn make_main_function<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - layout_ids: &mut LayoutIds<'a>, - layout: &Layout<'a>, - main_body: &roc_mono::ir::Stmt<'a>, -) -> (&'static str, &'a FunctionValue<'ctx>) { - // internal main function - let roc_main_fn = *build_roc_main(env, layout_ids, layout, main_body); - - make_main_function_help(env, layout, roc_main_fn) -} - -fn make_main_function_help<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - layout: &Layout<'a>, - roc_main_fn: FunctionValue<'ctx>, -) -> (&'static str, &'a FunctionValue<'ctx>) { - // build the C calling convention wrapper - use inkwell::types::BasicType; - use PassVia::*; - - let context = env.context; - let builder = env.builder; + let roc_main_fn = env.module.get_function(&fn_name).unwrap(); let main_fn_name = "$Test.main"; - let u8_ptr = env.context.i8_type().ptr_type(AddressSpace::Generic); - - let fields = [Layout::Builtin(Builtin::Int64), layout.clone()]; - let main_return_layout = Layout::Struct(&fields); - let main_return_type = block_of_memory(context, &main_return_layout, env.ptr_bytes); - - let register_or_memory = PassVia::from_layout(env.ptr_bytes, &main_return_layout); - - let main_fn_type = match register_or_memory { - Memory => { - let return_value_ptr = context.i64_type().ptr_type(AddressSpace::Generic).into(); - context.void_type().fn_type(&[return_value_ptr], false) - } - Register => main_return_type.fn_type(&[], false), - }; // Add main to the module. - let main_fn = env.module.add_function(main_fn_name, main_fn_type, None); + let main_fn = expose_function_to_host_help(env, roc_main_fn, main_fn_name); - // our exposed main function adheres to the C calling convention - main_fn.set_call_conventions(C_CALL_CONV); - - // Add main's body - let basic_block = context.append_basic_block(main_fn, "entry"); - let then_block = context.append_basic_block(main_fn, "then_block"); - let catch_block = context.append_basic_block(main_fn, "catch_block"); - let cont_block = context.append_basic_block(main_fn, "cont_block"); - - builder.position_at_end(basic_block); - - let result_alloca = builder.build_alloca(main_return_type, "result"); - - // invoke instead of call, so that we can catch any exeptions thrown in Roc code - let call_result = { - let call = builder.build_invoke(roc_main_fn, &[], then_block, catch_block, "call_roc_main"); - call.set_call_convention(FAST_CALL_CONV); - call.try_as_basic_value().left().unwrap() - }; - - // exception handling - { - builder.position_at_end(catch_block); - - let landing_pad_type = { - let exception_ptr = context.i8_type().ptr_type(AddressSpace::Generic).into(); - let selector_value = context.i32_type().into(); - - context.struct_type(&[exception_ptr, selector_value], false) - }; - - let info = builder - .build_catch_all_landing_pad( - &landing_pad_type, - &BasicValueEnum::IntValue(context.i8_type().const_zero()), - context.i8_type().ptr_type(AddressSpace::Generic), - "main_landing_pad", - ) - .into_struct_value(); - - let exception_ptr = builder - .build_extract_value(info, 0, "exception_ptr") - .unwrap(); - - let thrown = cxa_begin_catch(env, exception_ptr); - - let error_msg = { - let exception_type = u8_ptr; - let ptr = builder.build_bitcast( - thrown, - exception_type.ptr_type(AddressSpace::Generic), - "cast", - ); - - builder.build_load(ptr.into_pointer_value(), "error_msg") - }; - - let return_type = context.struct_type(&[context.i64_type().into(), u8_ptr.into()], false); - - let return_value = { - let v1 = return_type.const_zero(); - - // flag is non-zero, indicating failure - let flag = context.i64_type().const_int(1, false); - - let v2 = builder - .build_insert_value(v1, flag, 0, "set_error") - .unwrap(); - - let v3 = builder - .build_insert_value(v2, error_msg, 1, "set_exception") - .unwrap(); - - v3 - }; - - // bitcast result alloca so we can store our concrete type { flag, error_msg } in there - let result_alloca_bitcast = builder - .build_bitcast( - result_alloca, - return_type.ptr_type(AddressSpace::Generic), - "result_alloca_bitcast", - ) - .into_pointer_value(); - - // store our return value - builder.build_store(result_alloca_bitcast, return_value); - - cxa_end_catch(env); - - builder.build_unconditional_branch(cont_block); - } - - { - builder.position_at_end(then_block); - - let actual_return_type = - basic_type_from_layout(env.arena, env.context, layout, env.ptr_bytes); - let return_type = - context.struct_type(&[context.i64_type().into(), actual_return_type], false); - - let return_value = { - let v1 = return_type.const_zero(); - - let v2 = builder - .build_insert_value(v1, context.i64_type().const_zero(), 0, "set_no_error") - .unwrap(); - let v3 = builder - .build_insert_value(v2, call_result, 1, "set_call_result") - .unwrap(); - - v3 - }; - - let ptr = builder.build_bitcast( - result_alloca, - return_type.ptr_type(AddressSpace::Generic), - "name", - ); - builder.build_store(ptr.into_pointer_value(), return_value); - - builder.build_unconditional_branch(cont_block); - } - - { - builder.position_at_end(cont_block); - - let result = builder.build_load(result_alloca, "result"); - - match register_or_memory { - Memory => { - // write the result into the supplied pointer - let ptr_return_type = main_return_type.ptr_type(AddressSpace::Generic); - - let ptr_as_int = main_fn.get_first_param().unwrap(); - - let ptr = builder.build_bitcast(ptr_as_int, ptr_return_type, "caller_ptr"); - - builder.build_store(ptr.into_pointer_value(), result); - - // this is a void function, therefore return None - builder.build_return(None); - } - Register => { - // construct a normal return - // values are passed to the caller via registers - builder.build_return(Some(&result)); - } - } - } - - // MUST set the personality at the very end; - // doing it earlier can cause the personality to be ignored - let personality_func = get_gxx_personality_v0(env); - main_fn.set_personality_function(personality_func); - - (main_fn_name, env.arena.alloc(main_fn)) + (main_fn_name, main_fn) } fn get_inplace_from_layout(layout: &Layout<'_>) -> InPlace { @@ -1875,9 +1610,19 @@ fn expose_function_to_host<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, roc_function: FunctionValue<'ctx>, ) { + let c_function_name: String = format!("{}_exposed", roc_function.get_name().to_str().unwrap()); + + expose_function_to_host_help(env, roc_function, &c_function_name); +} + +fn expose_function_to_host_help<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + roc_function: FunctionValue<'ctx>, + c_function_name: &str, +) -> FunctionValue<'ctx> { use inkwell::types::BasicType; - let roc_wrapper_function = make_exception_catching_wrapper(env, roc_function); + let roc_wrapper_function = make_exception_catcher(env, roc_function); let roc_function_type = roc_wrapper_function.get_type(); @@ -1888,13 +1633,10 @@ fn expose_function_to_host<'a, 'ctx, 'env>( argument_types.push(output_type.into()); let c_function_type = env.context.void_type().fn_type(&argument_types, false); - let c_function_name: String = format!("{}_exposed", roc_function.get_name().to_str().unwrap()); - let c_function = env.module.add_function( - c_function_name.as_str(), - c_function_type, - Some(Linkage::External), - ); + let c_function = + env.module + .add_function(c_function_name, c_function_type, Some(Linkage::External)); // STEP 2: build the exposed function's body let builder = env.builder; @@ -1942,6 +1684,8 @@ fn expose_function_to_host<'a, 'ctx, 'env>( let size: BasicValueEnum = return_type.size_of().unwrap().into(); builder.build_return(Some(&size)); + + c_function } fn invoke_and_catch<'a, 'ctx, 'env, F, T>( @@ -2095,9 +1839,19 @@ where result } +fn make_exception_catcher<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + roc_function: FunctionValue<'ctx>, +) -> FunctionValue<'ctx> { + let wrapper_function_name = format!("{}_catcher", roc_function.get_name().to_str().unwrap()); + + make_exception_catching_wrapper(env, roc_function, &wrapper_function_name) +} + fn make_exception_catching_wrapper<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, roc_function: FunctionValue<'ctx>, + wrapper_function_name: &str, ) -> FunctionValue<'ctx> { // build the C calling convention wrapper @@ -2107,8 +1861,6 @@ fn make_exception_catching_wrapper<'a, 'ctx, 'env>( let roc_function_type = roc_function.get_type(); let argument_types = roc_function_type.get_param_types(); - let wrapper_function_name = format!("{}_catcher", roc_function.get_name().to_str().unwrap()); - let wrapper_return_type = context.struct_type( &[ context.i64_type().into(), @@ -2328,70 +2080,6 @@ pub fn build_closure_caller<'a, 'ctx, 'env>( builder.build_return(Some(&size)); } -#[allow(dead_code)] -pub fn build_closure_caller_old<'a, 'ctx, 'env>( - env: &'a Env<'a, 'ctx, 'env>, - closure_function: FunctionValue<'ctx>, -) { - let context = env.context; - let builder = env.builder; - // asuming the closure has type `a, b, closure_data -> c` - // change that into `a, b, *const closure_data, *mut output -> ()` - - // a function `a, b, closure_data -> RocCallResult` - let wrapped_function = make_exception_catching_wrapper(env, closure_function); - - let closure_function_type = closure_function.get_type(); - let wrapped_function_type = wrapped_function.get_type(); - - let mut arguments = closure_function_type.get_param_types(); - - // require that the closure data is passed by reference - let closure_data_type = arguments.pop().unwrap(); - let closure_data_ptr_type = get_ptr_type(&closure_data_type, AddressSpace::Generic); - arguments.push(closure_data_ptr_type.into()); - - // require that a pointer is passed in to write the result into - let output_type = get_ptr_type( - &wrapped_function_type.get_return_type().unwrap(), - AddressSpace::Generic, - ); - arguments.push(output_type.into()); - - let caller_function_type = env.context.void_type().fn_type(&arguments, false); - let caller_function_name: String = - format!("{}_caller", closure_function.get_name().to_str().unwrap()); - - let caller_function = env.module.add_function( - caller_function_name.as_str(), - caller_function_type, - Some(Linkage::External), - ); - - caller_function.set_call_conventions(C_CALL_CONV); - - let entry = context.append_basic_block(caller_function, "entry"); - - builder.position_at_end(entry); - - let mut parameters = caller_function.get_params(); - let output = parameters.pop().unwrap(); - let closure_data_ptr = parameters.pop().unwrap(); - - let closure_data = - builder.build_load(closure_data_ptr.into_pointer_value(), "load_closure_data"); - parameters.push(closure_data); - - let call = builder.build_call(wrapped_function, ¶meters, "call_wrapped_function"); - call.set_call_convention(FAST_CALL_CONV); - - let result = call.try_as_basic_value().left().unwrap(); - - builder.build_store(output.into_pointer_value(), result); - - builder.build_return(None); -} - pub fn build_proc<'a, 'ctx, 'env>( env: &'a Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, diff --git a/compiler/gen/src/run_roc.rs b/compiler/gen/src/run_roc.rs index c9b0aff8d3..024a5198ca 100644 --- a/compiler/gen/src/run_roc.rs +++ b/compiler/gen/src/run_roc.rs @@ -36,15 +36,20 @@ macro_rules! run_jit_function { ($lib: expr, $main_fn_name: expr, $ty:ty, $transform:expr, $errors:expr) => {{ use inkwell::context::Context; use roc_gen::run_roc::RocCallResult; + use std::mem::MaybeUninit; unsafe { - let main: libloading::Symbol RocCallResult<$ty>> = $lib - .get($main_fn_name.as_bytes()) - .ok() - .ok_or(format!("Unable to JIT compile `{}`", $main_fn_name)) - .expect("errored"); + let main: libloading::Symbol) -> ()> = + $lib.get($main_fn_name.as_bytes()) + .ok() + .ok_or(format!("Unable to JIT compile `{}`", $main_fn_name)) + .expect("errored"); - match main().into() { + let mut result = MaybeUninit::uninit(); + + main(result.as_mut_ptr()); + + match result.assume_init().into() { Ok(success) => { // only if there are no exceptions thrown, check for errors assert_eq!( diff --git a/compiler/gen/tests/gen_primitives.rs b/compiler/gen/tests/gen_primitives.rs index 710e2e6b01..b0dacec440 100644 --- a/compiler/gen/tests/gen_primitives.rs +++ b/compiler/gen/tests/gen_primitives.rs @@ -1064,4 +1064,51 @@ mod gen_primitives { f64 ); } + + #[test] + fn return_wrapped_function_pointer() { + assert_non_opt_evals_to!( + indoc!( + r#" + app Test provides [ main ] imports [] + + Effect a : [ @Effect ({} -> a) ] + + foo : Effect {} + foo = @Effect \{} -> {} + + main : Effect {} + main = foo + "# + ), + 1, + i64, + |_| 1 + ); + } + + #[test] + fn return_wrapped_closure() { + assert_non_opt_evals_to!( + indoc!( + r#" + app Test provides [ main ] imports [] + + Effect a : [ @Effect ({} -> a) ] + + foo : Effect {} + foo = + x = 5 + + @Effect (\{} -> if x > 3 then {} else {}) + + main : Effect {} + main = foo + "# + ), + 1, + i64, + |_| 1 + ); + } } diff --git a/compiler/load/src/file.rs b/compiler/load/src/file.rs index 0994e953c1..caefce7b3c 100644 --- a/compiler/load/src/file.rs +++ b/compiler/load/src/file.rs @@ -1613,7 +1613,7 @@ fn update<'a>( // state.timings.insert(module_id, module_timing); // display the mono IR of the module, for debug purposes - if false { + if roc_mono::ir::PRETTY_PRINT_IR_SYMBOLS { let procs_string = state .procedures .values() diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index 3b2d00ca5b..ec9aaf3547 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -14,6 +14,8 @@ use roc_types::subs::{Content, FlatType, Subs, Variable}; use std::collections::HashMap; use ven_pretty::{BoxAllocator, DocAllocator, DocBuilder}; +pub const PRETTY_PRINT_IR_SYMBOLS: bool = false; + #[derive(Clone, Debug, PartialEq)] pub enum MonoProblem { PatternProblem(crate::exhaustive::Error), @@ -902,8 +904,11 @@ where D::Doc: Clone, A: Clone, { - alloc.text(format!("{}", symbol)) - // alloc.text(format!("{:?}", symbol)) + if PRETTY_PRINT_IR_SYMBOLS { + alloc.text(format!("{:?}", symbol)) + } else { + alloc.text(format!("{}", symbol)) + } } fn join_point_to_doc<'b, D, A>(alloc: &'b D, symbol: JoinPointId) -> DocBuilder<'b, D, A> @@ -1514,51 +1519,9 @@ fn specialize_external<'a>( pattern_symbols }; - let (proc_args, opt_closure_layout, ret_layout) = + let specialized = build_specialized_proc_from_var(env, layout_cache, proc_name, pattern_symbols, fn_var)?; - let mut specialized_body = from_can(env, fn_var, body, procs, layout_cache); - // unpack the closure symbols, if any - if let CapturedSymbols::Captured(captured) = captured_symbols { - let mut layouts = Vec::with_capacity_in(captured.len(), env.arena); - - for (_, variable) in captured.iter() { - let layout = layout_cache.from_var(env.arena, *variable, env.subs)?; - layouts.push(layout); - } - - let field_layouts = layouts.into_bump_slice(); - - let wrapped = match &opt_closure_layout { - Some(x) => x.get_wrapped(), - None => unreachable!("symbols are captured, so this must be a closure"), - }; - - for (index, (symbol, variable)) in captured.iter().enumerate() { - // layout is cached anyway, re-using the one found above leads to - // issues (combining by-ref and by-move in pattern match - let layout = layout_cache.from_var(env.arena, *variable, env.subs)?; - - // if the symbol has a layout that is dropped from data structures (e.g. `{}`) - // then regenerate the symbol here. The value may not be present in the closure - // data struct - let expr = { - if layout.is_dropped_because_empty() { - Expr::Struct(&[]) - } else { - Expr::AccessAtIndex { - index: index as _, - field_layouts, - structure: Symbol::ARG_CLOSURE, - wrapped, - } - } - }; - - specialized_body = Stmt::Let(*symbol, expr, layout, env.arena.alloc(specialized_body)); - } - } - // determine the layout of aliases/rigids exposed to the host let host_exposed_layouts = if host_exposed_variables.is_empty() { HostExposedLayouts::NotHostExposed @@ -1578,39 +1541,141 @@ fn specialize_external<'a>( } }; - // reset subs, so we don't get type errors when specializing for a different signature - layout_cache.rollback_to(cache_snapshot); - env.subs.rollback_to(snapshot); - let recursivity = if is_self_recursive { SelfRecursive::SelfRecursive(JoinPointId(env.unique_symbol())) } else { SelfRecursive::NotSelfRecursive }; - let closure_data_layout = match opt_closure_layout { - Some(closure_layout) => Some(closure_layout.as_named_layout(proc_name)), - None => None, - }; + let mut specialized_body = from_can(env, fn_var, body, procs, layout_cache); - let proc = Proc { - name: proc_name, - args: proc_args, - body: specialized_body, - closure_data_layout, - ret_layout, - is_self_recursive: recursivity, - host_exposed_layouts, - }; + match specialized { + SpecializedLayout::FunctionPointerBody { + arguments, + ret_layout, + closure: opt_closure_layout, + } => { + // this is a function body like + // + // foo = Num.add + // + // we need to expand this to + // + // foo = \x,y -> Num.add x y - Ok(proc) + // reset subs, so we don't get type errors when specializing for a different signature + layout_cache.rollback_to(cache_snapshot); + env.subs.rollback_to(snapshot); + + let closure_data_layout = match opt_closure_layout { + Some(closure_layout) => Some(closure_layout.as_named_layout(proc_name)), + None => None, + }; + + // I'm not sure how to handle the closure case, does it ever occur? + debug_assert_eq!(closure_data_layout, None); + debug_assert!(matches!(captured_symbols, CapturedSymbols::None)); + + // this will be a thunk returning a function, so its ret_layout must be a function! + let full_layout = Layout::FunctionPointer(arguments, env.arena.alloc(ret_layout)); + + let proc = Proc { + name: proc_name, + args: &[], + body: specialized_body, + closure_data_layout, + ret_layout: full_layout, + is_self_recursive: recursivity, + host_exposed_layouts, + }; + + Ok(proc) + } + SpecializedLayout::FunctionBody { + arguments: proc_args, + closure: opt_closure_layout, + ret_layout, + } => { + // unpack the closure symbols, if any + if let CapturedSymbols::Captured(captured) = captured_symbols { + let mut layouts = Vec::with_capacity_in(captured.len(), env.arena); + + for (_, variable) in captured.iter() { + let layout = layout_cache.from_var(env.arena, *variable, env.subs)?; + layouts.push(layout); + } + + let field_layouts = layouts.into_bump_slice(); + + let wrapped = match &opt_closure_layout { + Some(x) => x.get_wrapped(), + None => unreachable!("symbols are captured, so this must be a closure"), + }; + + for (index, (symbol, variable)) in captured.iter().enumerate() { + // layout is cached anyway, re-using the one found above leads to + // issues (combining by-ref and by-move in pattern match + let layout = layout_cache.from_var(env.arena, *variable, env.subs)?; + + // if the symbol has a layout that is dropped from data structures (e.g. `{}`) + // then regenerate the symbol here. The value may not be present in the closure + // data struct + let expr = { + if layout.is_dropped_because_empty() { + Expr::Struct(&[]) + } else { + Expr::AccessAtIndex { + index: index as _, + field_layouts, + structure: Symbol::ARG_CLOSURE, + wrapped, + } + } + }; + + specialized_body = + Stmt::Let(*symbol, expr, layout, env.arena.alloc(specialized_body)); + } + } + + // reset subs, so we don't get type errors when specializing for a different signature + layout_cache.rollback_to(cache_snapshot); + env.subs.rollback_to(snapshot); + + let closure_data_layout = match opt_closure_layout { + Some(closure_layout) => Some(closure_layout.as_named_layout(proc_name)), + None => None, + }; + + let proc = Proc { + name: proc_name, + args: proc_args, + body: specialized_body, + closure_data_layout, + ret_layout, + is_self_recursive: recursivity, + host_exposed_layouts, + }; + + Ok(proc) + } + } } -type SpecializedLayout<'a> = ( - &'a [(Layout<'a>, Symbol)], - Option>, - Layout<'a>, -); +enum SpecializedLayout<'a> { + /// A body like `foo = \a,b,c -> ...` + FunctionBody { + arguments: &'a [(Layout<'a>, Symbol)], + closure: Option>, + ret_layout: Layout<'a>, + }, + /// A body like `foo = Num.add` + FunctionPointerBody { + arguments: &'a [Layout<'a>], + closure: Option>, + ret_layout: Layout<'a>, + }, +} #[allow(clippy::type_complexity)] fn build_specialized_proc_from_var<'a>( @@ -1739,6 +1804,7 @@ fn build_specialized_proc<'a>( let mut proc_args = Vec::with_capacity_in(pattern_layouts.len(), arena); let pattern_layouts_len = pattern_layouts.len(); + let pattern_layouts_slice = pattern_layouts.clone().into_bump_slice(); for (arg_layout, arg_name) in pattern_layouts.into_iter().zip(pattern_symbols.iter()) { proc_args.push((arg_layout, *arg_name)); @@ -1761,6 +1827,7 @@ fn build_specialized_proc<'a>( // f_closure = { ptr: f, closure: x } // // then + use SpecializedLayout::*; match opt_closure_layout { Some(layout) if pattern_symbols.last() == Some(&Symbol::ARG_CLOSURE) => { // here we define the lifted (now top-level) f function. Its final argument is `Symbol::ARG_CLOSURE`, @@ -1776,7 +1843,11 @@ fn build_specialized_proc<'a>( let proc_args = proc_args.into_bump_slice(); - Ok((proc_args, Some(layout), ret_layout)) + Ok(FunctionBody { + arguments: proc_args, + closure: Some(layout), + ret_layout, + }) } Some(layout) => { // else if there is a closure layout, we're building the `f_closure` value @@ -1791,7 +1862,11 @@ fn build_specialized_proc<'a>( let closure_layout = Layout::Struct(arena.alloc([function_ptr_layout, closure_data_layout])); - Ok((&[], None, closure_layout)) + Ok(FunctionBody { + arguments: &[], + closure: None, + ret_layout: closure_layout, + }) } None => { // else we're making a normal function, no closure problems to worry about @@ -1805,16 +1880,32 @@ fn build_specialized_proc<'a>( Ordering::Equal => { let proc_args = proc_args.into_bump_slice(); - Ok((proc_args, None, ret_layout)) + Ok(FunctionBody { + arguments: proc_args, + closure: None, + ret_layout, + }) } Ordering::Greater => { - // so far, the problem when hitting this branch was always somewhere else - // I think this branch should not be reachable in a bugfree compiler - panic!("more arguments (according to the layout) than argument symbols") - } - Ordering::Less => { - panic!("more argument symbols than arguments (according to the layout)") + if pattern_symbols.is_empty() { + Ok(FunctionPointerBody { + arguments: pattern_layouts_slice, + closure: None, + ret_layout, + }) + } else { + // so far, the problem when hitting this branch was always somewhere else + // I think this branch should not be reachable in a bugfree compiler + panic!( + "more arguments (according to the layout) than argument symbols for {:?}", + proc_name + ) + } } + Ordering::Less => panic!( + "more argument symbols than arguments (according to the layout) for {:?}", + proc_name + ), } } } @@ -2213,8 +2304,22 @@ pub fn with_hole<'a>( } else if symbol.module_id() != env.home && symbol.module_id() != ModuleId::ATTR { match layout_cache.from_var(env.arena, variable, env.subs) { Err(e) => panic!("invalid layout {:?}", e), - Ok(Layout::FunctionPointer(_, _)) => { + Ok(layout @ Layout::FunctionPointer(_, _)) => { add_needed_external(procs, env, variable, symbol); + + match hole { + Stmt::Jump(_, _) => todo!("not sure what to do in this case yet"), + _ => { + let expr = Expr::FunctionPointer(symbol, layout.clone()); + let new_symbol = env.unique_symbol(); + return Stmt::Let( + new_symbol, + expr, + layout, + env.arena.alloc(Stmt::Ret(new_symbol)), + ); + } + } } Ok(_) => { // this is a 0-arity thunk @@ -4678,6 +4783,12 @@ fn call_by_name<'a>( .specialized .contains_key(&(proc_name, full_layout.clone())) { + debug_assert_eq!( + arg_layouts.len(), + field_symbols.len(), + "see call_by_name for background (scroll down a bit)" + ); + let call = Expr::FunctionCall { call_type: CallType::ByName(proc_name), ret_layout: ret_layout.clone(), @@ -4721,6 +4832,11 @@ fn call_by_name<'a>( ); } + debug_assert_eq!( + arg_layouts.len(), + field_symbols.len(), + "see call_by_name for background (scroll down a bit)" + ); let call = Expr::FunctionCall { call_type: CallType::ByName(proc_name), ret_layout: ret_layout.clone(), @@ -4762,30 +4878,102 @@ fn call_by_name<'a>( let function_layout = FunctionLayouts::from_layout(env.arena, layout); - procs.specialized.remove(&(proc_name, full_layout)); + procs.specialized.remove(&(proc_name, full_layout.clone())); procs.specialized.insert( (proc_name, function_layout.full.clone()), Done(proc), ); - let call = Expr::FunctionCall { - call_type: CallType::ByName(proc_name), - ret_layout: function_layout.result.clone(), - full_layout: function_layout.full, - arg_layouts: function_layout.arguments, - args: field_symbols, - }; + if field_symbols.is_empty() { + debug_assert!(loc_args.is_empty()); - let iter = loc_args - .into_iter() - .rev() - .zip(field_symbols.iter().rev()); + // This happens when we return a function, e.g. + // + // foo = Num.add + // + // Even though the layout (and type) are functions, + // there are no arguments. This confuses our IR, + // and we have to fix it here. + match full_layout { + Layout::Closure(_, closure_layout, _) => { + let call = Expr::FunctionCall { + call_type: CallType::ByName(proc_name), + ret_layout: function_layout.result.clone(), + full_layout: function_layout.full.clone(), + arg_layouts: function_layout.arguments, + args: field_symbols, + }; - let result = - Stmt::Let(assigned, call, function_layout.result, hole); + // in the case of a closure specifically, we + // have to create a custom layout, to make sure + // the closure data is part of the layout + let closure_struct_layout = Layout::Struct( + env.arena.alloc([ + function_layout.full, + closure_layout + .as_block_of_memory_layout(), + ]), + ); - assign_to_symbols(env, procs, layout_cache, iter, result) + Stmt::Let( + assigned, + call, + closure_struct_layout, + hole, + ) + } + _ => { + let call = Expr::FunctionCall { + call_type: CallType::ByName(proc_name), + ret_layout: function_layout.result.clone(), + full_layout: function_layout.full.clone(), + arg_layouts: function_layout.arguments, + args: field_symbols, + }; + + Stmt::Let( + assigned, + call, + function_layout.full, + hole, + ) + } + } + } else { + debug_assert_eq!( + function_layout.arguments.len(), + field_symbols.len(), + "scroll up a bit for background" + ); + let call = Expr::FunctionCall { + call_type: CallType::ByName(proc_name), + ret_layout: function_layout.result.clone(), + full_layout: function_layout.full, + arg_layouts: function_layout.arguments, + args: field_symbols, + }; + + let iter = loc_args + .into_iter() + .rev() + .zip(field_symbols.iter().rev()); + + let result = Stmt::Let( + assigned, + call, + function_layout.result, + hole, + ); + + assign_to_symbols( + env, + procs, + layout_cache, + iter, + result, + ) + } } Err(error) => { let error_msg = env.arena.alloc(format!( @@ -4804,6 +4992,12 @@ fn call_by_name<'a>( None if assigned.module_id() != proc_name.module_id() => { add_needed_external(procs, env, original_fn_var, proc_name); + debug_assert_eq!( + arg_layouts.len(), + field_symbols.len(), + "scroll up a bit for background" + ); + let call = Expr::FunctionCall { call_type: CallType::ByName(proc_name), ret_layout: ret_layout.clone(), diff --git a/compiler/mono/tests/test_mono.rs b/compiler/mono/tests/test_mono.rs index f50f383d83..12546eaea5 100644 --- a/compiler/mono/tests/test_mono.rs +++ b/compiler/mono/tests/test_mono.rs @@ -1422,27 +1422,27 @@ mod test_mono { indoc!( r#" procedure Num.15 (#Attr.2, #Attr.3): - let Test.13 = lowlevel NumSub #Attr.2 #Attr.3; - ret Test.13; + let Test.14 = lowlevel NumSub #Attr.2 #Attr.3; + ret Test.14; procedure Num.16 (#Attr.2, #Attr.3): - let Test.11 = lowlevel NumMul #Attr.2 #Attr.3; - ret Test.11; + let Test.12 = lowlevel NumMul #Attr.2 #Attr.3; + ret Test.12; procedure Test.1 (Test.2, Test.3): - jump Test.18 Test.2 Test.3; - joinpoint Test.18 Test.2 Test.3: - let Test.15 = true; - let Test.16 = 0i64; - let Test.17 = lowlevel Eq Test.16 Test.2; - let Test.14 = lowlevel And Test.17 Test.15; - if Test.14 then + jump Test.7 Test.2 Test.3; + joinpoint Test.7 Test.2 Test.3: + let Test.16 = true; + let Test.17 = 0i64; + let Test.18 = lowlevel Eq Test.17 Test.2; + let Test.15 = lowlevel And Test.18 Test.16; + if Test.15 then ret Test.3; else - let Test.12 = 1i64; - let Test.9 = CallByName Num.15 Test.2 Test.12; - let Test.10 = CallByName Num.16 Test.2 Test.3; - jump Test.18 Test.9 Test.10; + let Test.13 = 1i64; + let Test.10 = CallByName Num.15 Test.2 Test.13; + let Test.11 = CallByName Num.16 Test.2 Test.3; + jump Test.7 Test.10 Test.11; procedure Test.0 (): let Test.5 = 10i64; diff --git a/examples/effect/Main.roc b/examples/effect/Main.roc index e91a8d1b48..7c7110693b 100644 --- a/examples/effect/Main.roc +++ b/examples/effect/Main.roc @@ -3,6 +3,5 @@ app Main provides [ main ] imports [ Effect ] main : Effect.Effect {} as Fx main = Effect.putLine "Write a thing!" - |> Effect.after (\{} -> Effect.getLine {}) + |> Effect.after (\{} -> Effect.getLine) |> Effect.after (\line -> Effect.putLine line) - diff --git a/examples/effect/platform/Pkg-Config.roc b/examples/effect/platform/Pkg-Config.roc index 9dd9ec3bc5..a463e3bbb9 100644 --- a/examples/effect/platform/Pkg-Config.roc +++ b/examples/effect/platform/Pkg-Config.roc @@ -6,5 +6,8 @@ platform folkertdev/foo { putChar : Int -> Effect {}, putLine : Str -> Effect {}, - getLine : {} -> Effect Str + getLine : Effect Str } + +mainForHost : Effect {} as Fx +mainForHost = main diff --git a/examples/effect/platform/src/lib.rs b/examples/effect/platform/src/lib.rs index dfe87a6db4..8e1f0448d4 100644 --- a/examples/effect/platform/src/lib.rs +++ b/examples/effect/platform/src/lib.rs @@ -38,7 +38,7 @@ pub fn roc_fx_putLine(line: RocStr) -> () { } #[no_mangle] -pub fn roc_fx_getLine(_: ()) -> RocStr { +pub fn roc_fx_getLine() -> RocStr { use std::io::{self, BufRead}; let stdin = io::stdin();