diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index d4c56854a2..dfa0c3c281 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -362,289 +362,24 @@ pub fn construct_optimization_passes<'a>( (mpm, fpm) } -/// For communication with C (tests and platforms) we need to abide by the C calling convention -/// -/// While small values are just returned like with the fast CC, larger structures need to -/// be written into a pointer (into the callers stack) -enum PassVia { - Register, - Memory, -} - -impl PassVia { - fn from_layout(ptr_bytes: u32, layout: &Layout<'_>) -> Self { - let stack_size = layout.stack_size(ptr_bytes); - let eightbyte = 8; - - if stack_size > 2 * eightbyte { - PassVia::Memory - } else { - PassVia::Register - } - } -} - -/// entry point to roc code; uses the fastcc calling convention -pub fn build_roc_main<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - layout_ids: &mut LayoutIds<'a>, - layout: &Layout<'a>, - main_body: &roc_mono::ir::Stmt<'a>, -) -> &'a FunctionValue<'ctx> { - use inkwell::types::BasicType; - - let context = env.context; - let builder = env.builder; - let arena = env.arena; - let ptr_bytes = env.ptr_bytes; - - let return_type = basic_type_from_layout(&arena, context, &layout, ptr_bytes); - let roc_main_fn_name = "$Test.roc_main"; - - // make the roc main function - let roc_main_fn_type = return_type.fn_type(&[], false); - - // Add main to the module. - let roc_main_fn = env - .module - .add_function(roc_main_fn_name, roc_main_fn_type, None); - - // internal function, use fast calling convention - roc_main_fn.set_call_conventions(FAST_CALL_CONV); - - // Add main's body - let basic_block = context.append_basic_block(roc_main_fn, "entry"); - - builder.position_at_end(basic_block); - - // builds the function body (return statement included) - build_exp_stmt( - env, - layout_ids, - &mut Scope::default(), - roc_main_fn, - main_body, - ); - - env.arena.alloc(roc_main_fn) -} - pub fn promote_to_main_function<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, symbol: Symbol, layout: &Layout<'a>, -) -> (&'static str, &'a FunctionValue<'ctx>) { +) -> (&'static str, FunctionValue<'ctx>) { let fn_name = layout_ids .get(symbol, layout) .to_symbol_string(symbol, &env.interns); - let wrapped = env.module.get_function(&fn_name).unwrap(); - - make_main_function_help(env, layout, wrapped) -} - -pub fn make_main_function<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - layout_ids: &mut LayoutIds<'a>, - layout: &Layout<'a>, - main_body: &roc_mono::ir::Stmt<'a>, -) -> (&'static str, &'a FunctionValue<'ctx>) { - // internal main function - let roc_main_fn = *build_roc_main(env, layout_ids, layout, main_body); - - make_main_function_help(env, layout, roc_main_fn) -} - -fn make_main_function_help<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - layout: &Layout<'a>, - roc_main_fn: FunctionValue<'ctx>, -) -> (&'static str, &'a FunctionValue<'ctx>) { - // build the C calling convention wrapper - use inkwell::types::BasicType; - use PassVia::*; - - let context = env.context; - let builder = env.builder; + let roc_main_fn = env.module.get_function(&fn_name).unwrap(); let main_fn_name = "$Test.main"; - let u8_ptr = env.context.i8_type().ptr_type(AddressSpace::Generic); - - let fields = [Layout::Builtin(Builtin::Int64), layout.clone()]; - let main_return_layout = Layout::Struct(&fields); - let main_return_type = block_of_memory(context, &main_return_layout, env.ptr_bytes); - - let register_or_memory = PassVia::from_layout(env.ptr_bytes, &main_return_layout); - - let main_fn_type = match register_or_memory { - Memory => { - let return_value_ptr = context.i64_type().ptr_type(AddressSpace::Generic).into(); - context.void_type().fn_type(&[return_value_ptr], false) - } - Register => main_return_type.fn_type(&[], false), - }; // Add main to the module. - let main_fn = env.module.add_function(main_fn_name, main_fn_type, None); + let main_fn = expose_function_to_host_help(env, roc_main_fn, main_fn_name); - // our exposed main function adheres to the C calling convention - main_fn.set_call_conventions(C_CALL_CONV); - - // Add main's body - let basic_block = context.append_basic_block(main_fn, "entry"); - let then_block = context.append_basic_block(main_fn, "then_block"); - let catch_block = context.append_basic_block(main_fn, "catch_block"); - let cont_block = context.append_basic_block(main_fn, "cont_block"); - - builder.position_at_end(basic_block); - - let result_alloca = builder.build_alloca(main_return_type, "result"); - - // invoke instead of call, so that we can catch any exeptions thrown in Roc code - let call_result = { - let call = builder.build_invoke(roc_main_fn, &[], then_block, catch_block, "call_roc_main"); - call.set_call_convention(FAST_CALL_CONV); - call.try_as_basic_value().left().unwrap() - }; - - // exception handling - { - builder.position_at_end(catch_block); - - let landing_pad_type = { - let exception_ptr = context.i8_type().ptr_type(AddressSpace::Generic).into(); - let selector_value = context.i32_type().into(); - - context.struct_type(&[exception_ptr, selector_value], false) - }; - - let info = builder - .build_catch_all_landing_pad( - &landing_pad_type, - &BasicValueEnum::IntValue(context.i8_type().const_zero()), - context.i8_type().ptr_type(AddressSpace::Generic), - "main_landing_pad", - ) - .into_struct_value(); - - let exception_ptr = builder - .build_extract_value(info, 0, "exception_ptr") - .unwrap(); - - let thrown = cxa_begin_catch(env, exception_ptr); - - let error_msg = { - let exception_type = u8_ptr; - let ptr = builder.build_bitcast( - thrown, - exception_type.ptr_type(AddressSpace::Generic), - "cast", - ); - - builder.build_load(ptr.into_pointer_value(), "error_msg") - }; - - let return_type = context.struct_type(&[context.i64_type().into(), u8_ptr.into()], false); - - let return_value = { - let v1 = return_type.const_zero(); - - // flag is non-zero, indicating failure - let flag = context.i64_type().const_int(1, false); - - let v2 = builder - .build_insert_value(v1, flag, 0, "set_error") - .unwrap(); - - let v3 = builder - .build_insert_value(v2, error_msg, 1, "set_exception") - .unwrap(); - - v3 - }; - - // bitcast result alloca so we can store our concrete type { flag, error_msg } in there - let result_alloca_bitcast = builder - .build_bitcast( - result_alloca, - return_type.ptr_type(AddressSpace::Generic), - "result_alloca_bitcast", - ) - .into_pointer_value(); - - // store our return value - builder.build_store(result_alloca_bitcast, return_value); - - cxa_end_catch(env); - - builder.build_unconditional_branch(cont_block); - } - - { - builder.position_at_end(then_block); - - let actual_return_type = - basic_type_from_layout(env.arena, env.context, layout, env.ptr_bytes); - let return_type = - context.struct_type(&[context.i64_type().into(), actual_return_type], false); - - let return_value = { - let v1 = return_type.const_zero(); - - let v2 = builder - .build_insert_value(v1, context.i64_type().const_zero(), 0, "set_no_error") - .unwrap(); - let v3 = builder - .build_insert_value(v2, call_result, 1, "set_call_result") - .unwrap(); - - v3 - }; - - let ptr = builder.build_bitcast( - result_alloca, - return_type.ptr_type(AddressSpace::Generic), - "name", - ); - builder.build_store(ptr.into_pointer_value(), return_value); - - builder.build_unconditional_branch(cont_block); - } - - { - builder.position_at_end(cont_block); - - let result = builder.build_load(result_alloca, "result"); - - match register_or_memory { - Memory => { - // write the result into the supplied pointer - let ptr_return_type = main_return_type.ptr_type(AddressSpace::Generic); - - let ptr_as_int = main_fn.get_first_param().unwrap(); - - let ptr = builder.build_bitcast(ptr_as_int, ptr_return_type, "caller_ptr"); - - builder.build_store(ptr.into_pointer_value(), result); - - // this is a void function, therefore return None - builder.build_return(None); - } - Register => { - // construct a normal return - // values are passed to the caller via registers - builder.build_return(Some(&result)); - } - } - } - - // MUST set the personality at the very end; - // doing it earlier can cause the personality to be ignored - let personality_func = get_gxx_personality_v0(env); - main_fn.set_personality_function(personality_func); - - (main_fn_name, env.arena.alloc(main_fn)) + (main_fn_name, main_fn) } fn get_inplace_from_layout(layout: &Layout<'_>) -> InPlace { @@ -1875,9 +1610,19 @@ fn expose_function_to_host<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, roc_function: FunctionValue<'ctx>, ) { + let c_function_name: String = format!("{}_exposed", roc_function.get_name().to_str().unwrap()); + + expose_function_to_host_help(env, roc_function, &c_function_name); +} + +fn expose_function_to_host_help<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + roc_function: FunctionValue<'ctx>, + c_function_name: &str, +) -> FunctionValue<'ctx> { use inkwell::types::BasicType; - let roc_wrapper_function = make_exception_catching_wrapper(env, roc_function); + let roc_wrapper_function = make_exception_catcher(env, roc_function); let roc_function_type = roc_wrapper_function.get_type(); @@ -1888,13 +1633,10 @@ fn expose_function_to_host<'a, 'ctx, 'env>( argument_types.push(output_type.into()); let c_function_type = env.context.void_type().fn_type(&argument_types, false); - let c_function_name: String = format!("{}_exposed", roc_function.get_name().to_str().unwrap()); - let c_function = env.module.add_function( - c_function_name.as_str(), - c_function_type, - Some(Linkage::External), - ); + let c_function = + env.module + .add_function(c_function_name, c_function_type, Some(Linkage::External)); // STEP 2: build the exposed function's body let builder = env.builder; @@ -1942,6 +1684,8 @@ fn expose_function_to_host<'a, 'ctx, 'env>( let size: BasicValueEnum = return_type.size_of().unwrap().into(); builder.build_return(Some(&size)); + + c_function } fn invoke_and_catch<'a, 'ctx, 'env, F, T>( @@ -2095,9 +1839,19 @@ where result } +fn make_exception_catcher<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + roc_function: FunctionValue<'ctx>, +) -> FunctionValue<'ctx> { + let wrapper_function_name = format!("{}_catcher", roc_function.get_name().to_str().unwrap()); + + make_exception_catching_wrapper(env, roc_function, &wrapper_function_name) +} + fn make_exception_catching_wrapper<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, roc_function: FunctionValue<'ctx>, + wrapper_function_name: &str, ) -> FunctionValue<'ctx> { // build the C calling convention wrapper @@ -2107,8 +1861,6 @@ fn make_exception_catching_wrapper<'a, 'ctx, 'env>( let roc_function_type = roc_function.get_type(); let argument_types = roc_function_type.get_param_types(); - let wrapper_function_name = format!("{}_catcher", roc_function.get_name().to_str().unwrap()); - let wrapper_return_type = context.struct_type( &[ context.i64_type().into(), @@ -2328,70 +2080,6 @@ pub fn build_closure_caller<'a, 'ctx, 'env>( builder.build_return(Some(&size)); } -#[allow(dead_code)] -pub fn build_closure_caller_old<'a, 'ctx, 'env>( - env: &'a Env<'a, 'ctx, 'env>, - closure_function: FunctionValue<'ctx>, -) { - let context = env.context; - let builder = env.builder; - // asuming the closure has type `a, b, closure_data -> c` - // change that into `a, b, *const closure_data, *mut output -> ()` - - // a function `a, b, closure_data -> RocCallResult` - let wrapped_function = make_exception_catching_wrapper(env, closure_function); - - let closure_function_type = closure_function.get_type(); - let wrapped_function_type = wrapped_function.get_type(); - - let mut arguments = closure_function_type.get_param_types(); - - // require that the closure data is passed by reference - let closure_data_type = arguments.pop().unwrap(); - let closure_data_ptr_type = get_ptr_type(&closure_data_type, AddressSpace::Generic); - arguments.push(closure_data_ptr_type.into()); - - // require that a pointer is passed in to write the result into - let output_type = get_ptr_type( - &wrapped_function_type.get_return_type().unwrap(), - AddressSpace::Generic, - ); - arguments.push(output_type.into()); - - let caller_function_type = env.context.void_type().fn_type(&arguments, false); - let caller_function_name: String = - format!("{}_caller", closure_function.get_name().to_str().unwrap()); - - let caller_function = env.module.add_function( - caller_function_name.as_str(), - caller_function_type, - Some(Linkage::External), - ); - - caller_function.set_call_conventions(C_CALL_CONV); - - let entry = context.append_basic_block(caller_function, "entry"); - - builder.position_at_end(entry); - - let mut parameters = caller_function.get_params(); - let output = parameters.pop().unwrap(); - let closure_data_ptr = parameters.pop().unwrap(); - - let closure_data = - builder.build_load(closure_data_ptr.into_pointer_value(), "load_closure_data"); - parameters.push(closure_data); - - let call = builder.build_call(wrapped_function, ¶meters, "call_wrapped_function"); - call.set_call_convention(FAST_CALL_CONV); - - let result = call.try_as_basic_value().left().unwrap(); - - builder.build_store(output.into_pointer_value(), result); - - builder.build_return(None); -} - pub fn build_proc<'a, 'ctx, 'env>( env: &'a Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, diff --git a/compiler/gen/src/run_roc.rs b/compiler/gen/src/run_roc.rs index c9b0aff8d3..024a5198ca 100644 --- a/compiler/gen/src/run_roc.rs +++ b/compiler/gen/src/run_roc.rs @@ -36,15 +36,20 @@ macro_rules! run_jit_function { ($lib: expr, $main_fn_name: expr, $ty:ty, $transform:expr, $errors:expr) => {{ use inkwell::context::Context; use roc_gen::run_roc::RocCallResult; + use std::mem::MaybeUninit; unsafe { - let main: libloading::Symbol RocCallResult<$ty>> = $lib - .get($main_fn_name.as_bytes()) - .ok() - .ok_or(format!("Unable to JIT compile `{}`", $main_fn_name)) - .expect("errored"); + let main: libloading::Symbol) -> ()> = + $lib.get($main_fn_name.as_bytes()) + .ok() + .ok_or(format!("Unable to JIT compile `{}`", $main_fn_name)) + .expect("errored"); - match main().into() { + let mut result = MaybeUninit::uninit(); + + main(result.as_mut_ptr()); + + match result.assume_init().into() { Ok(success) => { // only if there are no exceptions thrown, check for errors assert_eq!( diff --git a/compiler/gen/tests/gen_primitives.rs b/compiler/gen/tests/gen_primitives.rs index 6a6a9dbc29..b0dacec440 100644 --- a/compiler/gen/tests/gen_primitives.rs +++ b/compiler/gen/tests/gen_primitives.rs @@ -1088,7 +1088,6 @@ mod gen_primitives { } #[test] - #[ignore] fn return_wrapped_closure() { assert_non_opt_evals_to!( indoc!( @@ -1101,7 +1100,7 @@ mod gen_primitives { foo = x = 5 - @Effect \{} -> if x > 3 then {} else {} + @Effect (\{} -> if x > 3 then {} else {}) main : Effect {} main = foo