diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index 74e5348653..9571783bc5 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -1395,23 +1395,122 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>( env.builder.build_store(alloca, val); - // Make a new scope which includes the binding we just encountered. - // This should be done *after* compiling the bound expr, since any - // recursive (in the LetRec sense) bindings should already have - // been extracted as procedures. Nothing in here should need to - // access itself! - // scope = scope.clone(); + let arguments = call.arguments; + match call.call_type { + CallType::ByName { + name, + ref full_layout, + .. + } => { + let function_value = function_value_by_name(env, layout_ids, full_layout, name); - let pass_branch = { - scope.insert(*symbol, (layout.clone(), alloca)); + let mut arg_vals: Vec = + Vec::with_capacity_in(arguments.len(), env.arena); - let result = build_exp_stmt(env, layout_ids, scope, parent, pass); + for arg in arguments.iter() { + arg_vals.push(load_symbol(env, scope, arg)); + } - scope.remove(symbol); - result - }; + let pass_block = context.append_basic_block(parent, "invoke_pass"); + let fail_block = context.append_basic_block(parent, "invoke_fail"); - pass_branch + { + env.builder.position_at_end(pass_block); + + scope.insert(*symbol, (layout.clone(), alloca)); + + build_exp_stmt(env, layout_ids, scope, parent, pass); + + scope.remove(symbol); + } + + { + env.builder.position_at_end(fail_block); + + build_exp_stmt(env, layout_ids, scope, parent, fail); + } + + let call = env.builder.build_invoke( + function_value, + arg_vals.as_slice(), + pass_block, + fail_block, + "tmp", + ); + + if env.exposed_to_host.contains(&name) { + // If this is an external-facing function, use the C calling convention. + call.set_call_convention(C_CALL_CONV); + } else { + // If it's an internal-only function, use the fast calling convention. + call.set_call_convention(FAST_CALL_CONV); + } + + call.try_as_basic_value() + .left() + .unwrap_or_else(|| panic!("LLVM error: Invalid call by pointer.")) + } + CallType::ByPointer { name, .. } => { + let sub_expr = load_symbol(env, scope, &name); + + let mut arg_vals: Vec = + Vec::with_capacity_in(arguments.len(), env.arena); + + for arg in arguments.iter() { + arg_vals.push(load_symbol(env, scope, arg)); + } + + let pass_block = context.append_basic_block(parent, "invoke_pass"); + let fail_block = context.append_basic_block(parent, "invoke_fail"); + + { + env.builder.position_at_end(pass_block); + + scope.insert(*symbol, (layout.clone(), alloca)); + + build_exp_stmt(env, layout_ids, scope, parent, pass); + + scope.remove(symbol); + } + + { + env.builder.position_at_end(fail_block); + + build_exp_stmt(env, layout_ids, scope, parent, fail); + } + + let call = match sub_expr { + BasicValueEnum::PointerValue(ptr) => env.builder.build_invoke( + ptr, + arg_vals.as_slice(), + pass_block, + fail_block, + "tmp", + ), + non_ptr => { + panic!( + "Tried to call by pointer, but encountered a non-pointer: {:?}", + non_ptr + ); + } + }; + + if env.exposed_to_host.contains(&name) { + // If this is an external-facing function, use the C calling convention. + call.set_call_convention(C_CALL_CONV); + } else { + // If it's an internal-only function, use the fast calling convention. + call.set_call_convention(FAST_CALL_CONV); + } + + call.try_as_basic_value() + .left() + .unwrap_or_else(|| panic!("LLVM error: Invalid call by pointer.")) + } + _ => { + todo!() + } + } } Unreachable => { @@ -2579,6 +2678,29 @@ pub fn verify_fn(fn_val: FunctionValue<'_>) { } } +fn function_value_by_name<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + layout: &Layout<'a>, + symbol: Symbol, +) -> FunctionValue<'ctx> { + let fn_name = layout_ids + .get(symbol, layout) + .to_symbol_string(symbol, &env.interns); + let fn_name = fn_name.as_str(); + + env.module.get_function(fn_name).unwrap_or_else(|| { + if symbol.is_builtin() { + panic!("Unrecognized builtin function: {:?}", fn_name) + } else { + panic!( + "Unrecognized non-builtin function: {:?} (symbol: {:?}, layout: {:?})", + fn_name, symbol, layout + ) + } + }) +} + // #[allow(clippy::cognitive_complexity)] #[inline(always)] fn call_with_args<'a, 'ctx, 'env>( @@ -2589,21 +2711,7 @@ fn call_with_args<'a, 'ctx, 'env>( _parent: FunctionValue<'ctx>, args: &[BasicValueEnum<'ctx>], ) -> BasicValueEnum<'ctx> { - let fn_name = layout_ids - .get(symbol, layout) - .to_symbol_string(symbol, &env.interns); - let fn_name = fn_name.as_str(); - - let fn_val = env.module.get_function(fn_name).unwrap_or_else(|| { - if symbol.is_builtin() { - panic!("Unrecognized builtin function: {:?}", fn_name) - } else { - panic!( - "Unrecognized non-builtin function: {:?} (symbol: {:?}, layout: {:?})", - fn_name, symbol, layout - ) - } - }); + let fn_val = function_value_by_name(env, layout_ids, layout, symbol); let call = env.builder.build_call(fn_val, args, "call");