diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index 7ec2d58b15..8d6fb22760 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -1844,11 +1844,13 @@ pub fn create_entry_block_alloca<'a, 'ctx>( fn expose_function_to_host<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, - roc_function: &FunctionValue<'ctx>, + roc_function: FunctionValue<'ctx>, ) { use inkwell::types::BasicType; - let roc_function_type = roc_function.get_type(); + let roc_wrapper_function = make_exception_catching_wrapper(env, roc_function); + + let roc_function_type = roc_wrapper_function.get_type(); // STEP 1: turn `f : a,b,c -> d` into `f : a,b,c, &d -> {}` let mut argument_types = roc_function_type.get_param_types(); @@ -1879,8 +1881,9 @@ fn expose_function_to_host<'a, 'ctx, 'env>( let args = &args[..args.len() - 1]; debug_assert_eq!(args.len(), roc_function.get_params().len()); + debug_assert_eq!(args.len(), roc_wrapper_function.get_params().len()); - let call_wrapped = builder.build_call(roc_function.clone(), args, "call_wrapped_function"); + let call_wrapped = builder.build_call(roc_wrapper_function, args, "call_wrapped_function"); call_wrapped.set_call_convention(FAST_CALL_CONV); let call_result = call_wrapped.try_as_basic_value().left().unwrap(); @@ -1912,6 +1915,173 @@ fn expose_function_to_host<'a, 'ctx, 'env>( builder.build_return(Some(&size)); } +fn make_exception_catching_wrapper<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + roc_function: FunctionValue<'ctx>, +) -> FunctionValue<'ctx> { + // build the C calling convention wrapper + + let context = env.context; + let builder = env.builder; + + let u8_ptr = env.context.i8_type().ptr_type(AddressSpace::Generic); + + let roc_function_type = roc_function.get_type(); + let argument_types = roc_function_type.get_param_types(); + + let wrapper_function_name = format!("{}_catcher", roc_function.get_name().to_str().unwrap()); + + let wrapper_return_type = context.struct_type( + &[ + context.i64_type().into(), + roc_function_type.get_return_type().unwrap(), + ], + false, + ); + + let wrapper_function_type = wrapper_return_type.fn_type(&argument_types, false); + + // Add main to the module. + let wrapper_function = + env.module + .add_function(&wrapper_function_name, wrapper_function_type, None); + + // our exposed main function adheres to the C calling convention + wrapper_function.set_call_conventions(FAST_CALL_CONV); + + // Add main's body + let basic_block = context.append_basic_block(wrapper_function, "entry"); + let then_block = context.append_basic_block(wrapper_function, "then_block"); + let catch_block = context.append_basic_block(wrapper_function, "catch_block"); + let cont_block = context.append_basic_block(wrapper_function, "cont_block"); + + builder.position_at_end(basic_block); + + let result_alloca = builder.build_alloca(wrapper_return_type, "result"); + + // invoke instead of call, so that we can catch any exeptions thrown in Roc code + let call_result = { + let call = + builder.build_invoke(roc_function, &[], then_block, catch_block, "call_roc_main"); + call.set_call_convention(FAST_CALL_CONV); + call.try_as_basic_value().left().unwrap() + }; + + // exception handling + { + builder.position_at_end(catch_block); + + let landing_pad_type = { + let exception_ptr = context.i8_type().ptr_type(AddressSpace::Generic).into(); + let selector_value = context.i32_type().into(); + + context.struct_type(&[exception_ptr, selector_value], false) + }; + + let info = builder + .build_catch_all_landing_pad( + &landing_pad_type, + &BasicValueEnum::IntValue(context.i8_type().const_zero()), + context.i8_type().ptr_type(AddressSpace::Generic), + "main_landing_pad", + ) + .into_struct_value(); + + let exception_ptr = builder + .build_extract_value(info, 0, "exception_ptr") + .unwrap(); + + let thrown = cxa_begin_catch(env, exception_ptr); + + let error_msg = { + let exception_type = u8_ptr; + let ptr = builder.build_bitcast( + thrown, + exception_type.ptr_type(AddressSpace::Generic), + "cast", + ); + + builder.build_load(ptr.into_pointer_value(), "error_msg") + }; + + let return_type = context.struct_type(&[context.i64_type().into(), u8_ptr.into()], false); + + let return_value = { + let v1 = return_type.const_zero(); + + // flag is non-zero, indicating failure + let flag = context.i64_type().const_int(1, false); + + let v2 = builder + .build_insert_value(v1, flag, 0, "set_error") + .unwrap(); + + let v3 = builder + .build_insert_value(v2, error_msg, 1, "set_exception") + .unwrap(); + + v3 + }; + + // bitcast result alloca so we can store our concrete type { flag, error_msg } in there + let result_alloca_bitcast = builder + .build_bitcast( + result_alloca, + return_type.ptr_type(AddressSpace::Generic), + "result_alloca_bitcast", + ) + .into_pointer_value(); + + // store our return value + builder.build_store(result_alloca_bitcast, return_value); + + cxa_end_catch(env); + + builder.build_unconditional_branch(cont_block); + } + + { + builder.position_at_end(then_block); + + let return_value = { + let v1 = wrapper_return_type.const_zero(); + + let v2 = builder + .build_insert_value(v1, context.i64_type().const_zero(), 0, "set_no_error") + .unwrap(); + let v3 = builder + .build_insert_value(v2, call_result, 1, "set_call_result") + .unwrap(); + + v3 + }; + + let ptr = builder.build_bitcast( + result_alloca, + wrapper_return_type.ptr_type(AddressSpace::Generic), + "name", + ); + builder.build_store(ptr.into_pointer_value(), return_value); + + builder.build_unconditional_branch(cont_block); + } + + { + builder.position_at_end(cont_block); + + let result = builder.build_load(result_alloca, "result"); + + builder.build_return(Some(&result)); + } + + // MUST set the personality at the very end; + // doing it earlier can cause the personality to be ignored + let personality_func = get_gxx_personality_v0(env); + wrapper_function.set_personality_function(personality_func); + + wrapper_function +} + pub fn build_proc_header<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, @@ -1944,7 +2114,7 @@ pub fn build_proc_header<'a, 'ctx, 'env>( fn_val.set_call_conventions(FAST_CALL_CONV); if env.exposed_to_host.contains(&symbol) { - expose_function_to_host(env, &fn_val); + expose_function_to_host(env, fn_val); } fn_val diff --git a/examples/closure/platform/src/lib.rs b/examples/closure/platform/src/lib.rs index c57f5a8502..d19e5ab134 100644 --- a/examples/closure/platform/src/lib.rs +++ b/examples/closure/platform/src/lib.rs @@ -1,5 +1,8 @@ +use std::ffi::CString; use std::mem::MaybeUninit; +use std::os::raw::c_char; use std::time::SystemTime; +use RocCallResult::*; extern "C" { #[link_name = "closure_1_exposed"] @@ -13,14 +16,19 @@ extern "C" { pub fn rust_main() -> isize { println!("Running Roc closure"); let start_time = SystemTime::now(); - let (function_pointer, closure_data) = unsafe { - let mut output: MaybeUninit<(fn(i64) -> i64, i64)> = MaybeUninit::uninit(); + + let size = unsafe { closure_size() } as usize; + let roc_closure = unsafe { + let mut output: MaybeUninit i64, i64)>> = MaybeUninit::uninit(); closure(output.as_mut_ptr() as _); - output.assume_init() + match output.assume_init().into() { + Ok((function_pointer, closure_data)) => move || function_pointer(closure_data), + Err(msg) => panic!("Roc failed with message: {}", msg), + } }; - let answer = function_pointer(closure_data); + let answer = roc_closure(); let end_time = SystemTime::now(); let duration = end_time.duration_since(start_time).unwrap(); @@ -31,8 +39,30 @@ pub fn rust_main() -> isize { answer ); - println!("closure size {:?}", unsafe { closure_size() }); - // Exit code 0 } + +#[repr(u64)] +pub enum RocCallResult { + Success(T), + Failure(*mut c_char), +} + +impl Into> for RocCallResult { + fn into(self) -> Result { + match self { + Success(value) => Ok(value), + Failure(failure) => Err({ + let raw = unsafe { CString::from_raw(failure) }; + + let result = format!("{:?}", raw); + + // make sure rust does not try to free the Roc string + std::mem::forget(raw); + + result + }), + } + } +}