mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-28 22:34:45 +00:00
Merge pull request #666 from rtfeldman/return-function
Return function pointers and closures
This commit is contained in:
commit
54de538952
9 changed files with 397 additions and 461 deletions
|
@ -362,289 +362,24 @@ pub fn construct_optimization_passes<'a>(
|
|||
(mpm, fpm)
|
||||
}
|
||||
|
||||
/// For communication with C (tests and platforms) we need to abide by the C calling convention
|
||||
///
|
||||
/// While small values are just returned like with the fast CC, larger structures need to
|
||||
/// be written into a pointer (into the callers stack)
|
||||
enum PassVia {
|
||||
Register,
|
||||
Memory,
|
||||
}
|
||||
|
||||
impl PassVia {
|
||||
fn from_layout(ptr_bytes: u32, layout: &Layout<'_>) -> Self {
|
||||
let stack_size = layout.stack_size(ptr_bytes);
|
||||
let eightbyte = 8;
|
||||
|
||||
if stack_size > 2 * eightbyte {
|
||||
PassVia::Memory
|
||||
} else {
|
||||
PassVia::Register
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// entry point to roc code; uses the fastcc calling convention
|
||||
pub fn build_roc_main<'a, 'ctx, 'env>(
|
||||
env: &Env<'a, 'ctx, 'env>,
|
||||
layout_ids: &mut LayoutIds<'a>,
|
||||
layout: &Layout<'a>,
|
||||
main_body: &roc_mono::ir::Stmt<'a>,
|
||||
) -> &'a FunctionValue<'ctx> {
|
||||
use inkwell::types::BasicType;
|
||||
|
||||
let context = env.context;
|
||||
let builder = env.builder;
|
||||
let arena = env.arena;
|
||||
let ptr_bytes = env.ptr_bytes;
|
||||
|
||||
let return_type = basic_type_from_layout(&arena, context, &layout, ptr_bytes);
|
||||
let roc_main_fn_name = "$Test.roc_main";
|
||||
|
||||
// make the roc main function
|
||||
let roc_main_fn_type = return_type.fn_type(&[], false);
|
||||
|
||||
// Add main to the module.
|
||||
let roc_main_fn = env
|
||||
.module
|
||||
.add_function(roc_main_fn_name, roc_main_fn_type, None);
|
||||
|
||||
// internal function, use fast calling convention
|
||||
roc_main_fn.set_call_conventions(FAST_CALL_CONV);
|
||||
|
||||
// Add main's body
|
||||
let basic_block = context.append_basic_block(roc_main_fn, "entry");
|
||||
|
||||
builder.position_at_end(basic_block);
|
||||
|
||||
// builds the function body (return statement included)
|
||||
build_exp_stmt(
|
||||
env,
|
||||
layout_ids,
|
||||
&mut Scope::default(),
|
||||
roc_main_fn,
|
||||
main_body,
|
||||
);
|
||||
|
||||
env.arena.alloc(roc_main_fn)
|
||||
}
|
||||
|
||||
pub fn promote_to_main_function<'a, 'ctx, 'env>(
|
||||
env: &Env<'a, 'ctx, 'env>,
|
||||
layout_ids: &mut LayoutIds<'a>,
|
||||
symbol: Symbol,
|
||||
layout: &Layout<'a>,
|
||||
) -> (&'static str, &'a FunctionValue<'ctx>) {
|
||||
) -> (&'static str, FunctionValue<'ctx>) {
|
||||
let fn_name = layout_ids
|
||||
.get(symbol, layout)
|
||||
.to_symbol_string(symbol, &env.interns);
|
||||
|
||||
let wrapped = env.module.get_function(&fn_name).unwrap();
|
||||
|
||||
make_main_function_help(env, layout, wrapped)
|
||||
}
|
||||
|
||||
pub fn make_main_function<'a, 'ctx, 'env>(
|
||||
env: &Env<'a, 'ctx, 'env>,
|
||||
layout_ids: &mut LayoutIds<'a>,
|
||||
layout: &Layout<'a>,
|
||||
main_body: &roc_mono::ir::Stmt<'a>,
|
||||
) -> (&'static str, &'a FunctionValue<'ctx>) {
|
||||
// internal main function
|
||||
let roc_main_fn = *build_roc_main(env, layout_ids, layout, main_body);
|
||||
|
||||
make_main_function_help(env, layout, roc_main_fn)
|
||||
}
|
||||
|
||||
fn make_main_function_help<'a, 'ctx, 'env>(
|
||||
env: &Env<'a, 'ctx, 'env>,
|
||||
layout: &Layout<'a>,
|
||||
roc_main_fn: FunctionValue<'ctx>,
|
||||
) -> (&'static str, &'a FunctionValue<'ctx>) {
|
||||
// build the C calling convention wrapper
|
||||
use inkwell::types::BasicType;
|
||||
use PassVia::*;
|
||||
|
||||
let context = env.context;
|
||||
let builder = env.builder;
|
||||
let roc_main_fn = env.module.get_function(&fn_name).unwrap();
|
||||
|
||||
let main_fn_name = "$Test.main";
|
||||
let u8_ptr = env.context.i8_type().ptr_type(AddressSpace::Generic);
|
||||
|
||||
let fields = [Layout::Builtin(Builtin::Int64), layout.clone()];
|
||||
let main_return_layout = Layout::Struct(&fields);
|
||||
let main_return_type = block_of_memory(context, &main_return_layout, env.ptr_bytes);
|
||||
|
||||
let register_or_memory = PassVia::from_layout(env.ptr_bytes, &main_return_layout);
|
||||
|
||||
let main_fn_type = match register_or_memory {
|
||||
Memory => {
|
||||
let return_value_ptr = context.i64_type().ptr_type(AddressSpace::Generic).into();
|
||||
context.void_type().fn_type(&[return_value_ptr], false)
|
||||
}
|
||||
Register => main_return_type.fn_type(&[], false),
|
||||
};
|
||||
|
||||
// Add main to the module.
|
||||
let main_fn = env.module.add_function(main_fn_name, main_fn_type, None);
|
||||
let main_fn = expose_function_to_host_help(env, roc_main_fn, main_fn_name);
|
||||
|
||||
// our exposed main function adheres to the C calling convention
|
||||
main_fn.set_call_conventions(C_CALL_CONV);
|
||||
|
||||
// Add main's body
|
||||
let basic_block = context.append_basic_block(main_fn, "entry");
|
||||
let then_block = context.append_basic_block(main_fn, "then_block");
|
||||
let catch_block = context.append_basic_block(main_fn, "catch_block");
|
||||
let cont_block = context.append_basic_block(main_fn, "cont_block");
|
||||
|
||||
builder.position_at_end(basic_block);
|
||||
|
||||
let result_alloca = builder.build_alloca(main_return_type, "result");
|
||||
|
||||
// invoke instead of call, so that we can catch any exeptions thrown in Roc code
|
||||
let call_result = {
|
||||
let call = builder.build_invoke(roc_main_fn, &[], then_block, catch_block, "call_roc_main");
|
||||
call.set_call_convention(FAST_CALL_CONV);
|
||||
call.try_as_basic_value().left().unwrap()
|
||||
};
|
||||
|
||||
// exception handling
|
||||
{
|
||||
builder.position_at_end(catch_block);
|
||||
|
||||
let landing_pad_type = {
|
||||
let exception_ptr = context.i8_type().ptr_type(AddressSpace::Generic).into();
|
||||
let selector_value = context.i32_type().into();
|
||||
|
||||
context.struct_type(&[exception_ptr, selector_value], false)
|
||||
};
|
||||
|
||||
let info = builder
|
||||
.build_catch_all_landing_pad(
|
||||
&landing_pad_type,
|
||||
&BasicValueEnum::IntValue(context.i8_type().const_zero()),
|
||||
context.i8_type().ptr_type(AddressSpace::Generic),
|
||||
"main_landing_pad",
|
||||
)
|
||||
.into_struct_value();
|
||||
|
||||
let exception_ptr = builder
|
||||
.build_extract_value(info, 0, "exception_ptr")
|
||||
.unwrap();
|
||||
|
||||
let thrown = cxa_begin_catch(env, exception_ptr);
|
||||
|
||||
let error_msg = {
|
||||
let exception_type = u8_ptr;
|
||||
let ptr = builder.build_bitcast(
|
||||
thrown,
|
||||
exception_type.ptr_type(AddressSpace::Generic),
|
||||
"cast",
|
||||
);
|
||||
|
||||
builder.build_load(ptr.into_pointer_value(), "error_msg")
|
||||
};
|
||||
|
||||
let return_type = context.struct_type(&[context.i64_type().into(), u8_ptr.into()], false);
|
||||
|
||||
let return_value = {
|
||||
let v1 = return_type.const_zero();
|
||||
|
||||
// flag is non-zero, indicating failure
|
||||
let flag = context.i64_type().const_int(1, false);
|
||||
|
||||
let v2 = builder
|
||||
.build_insert_value(v1, flag, 0, "set_error")
|
||||
.unwrap();
|
||||
|
||||
let v3 = builder
|
||||
.build_insert_value(v2, error_msg, 1, "set_exception")
|
||||
.unwrap();
|
||||
|
||||
v3
|
||||
};
|
||||
|
||||
// bitcast result alloca so we can store our concrete type { flag, error_msg } in there
|
||||
let result_alloca_bitcast = builder
|
||||
.build_bitcast(
|
||||
result_alloca,
|
||||
return_type.ptr_type(AddressSpace::Generic),
|
||||
"result_alloca_bitcast",
|
||||
)
|
||||
.into_pointer_value();
|
||||
|
||||
// store our return value
|
||||
builder.build_store(result_alloca_bitcast, return_value);
|
||||
|
||||
cxa_end_catch(env);
|
||||
|
||||
builder.build_unconditional_branch(cont_block);
|
||||
}
|
||||
|
||||
{
|
||||
builder.position_at_end(then_block);
|
||||
|
||||
let actual_return_type =
|
||||
basic_type_from_layout(env.arena, env.context, layout, env.ptr_bytes);
|
||||
let return_type =
|
||||
context.struct_type(&[context.i64_type().into(), actual_return_type], false);
|
||||
|
||||
let return_value = {
|
||||
let v1 = return_type.const_zero();
|
||||
|
||||
let v2 = builder
|
||||
.build_insert_value(v1, context.i64_type().const_zero(), 0, "set_no_error")
|
||||
.unwrap();
|
||||
let v3 = builder
|
||||
.build_insert_value(v2, call_result, 1, "set_call_result")
|
||||
.unwrap();
|
||||
|
||||
v3
|
||||
};
|
||||
|
||||
let ptr = builder.build_bitcast(
|
||||
result_alloca,
|
||||
return_type.ptr_type(AddressSpace::Generic),
|
||||
"name",
|
||||
);
|
||||
builder.build_store(ptr.into_pointer_value(), return_value);
|
||||
|
||||
builder.build_unconditional_branch(cont_block);
|
||||
}
|
||||
|
||||
{
|
||||
builder.position_at_end(cont_block);
|
||||
|
||||
let result = builder.build_load(result_alloca, "result");
|
||||
|
||||
match register_or_memory {
|
||||
Memory => {
|
||||
// write the result into the supplied pointer
|
||||
let ptr_return_type = main_return_type.ptr_type(AddressSpace::Generic);
|
||||
|
||||
let ptr_as_int = main_fn.get_first_param().unwrap();
|
||||
|
||||
let ptr = builder.build_bitcast(ptr_as_int, ptr_return_type, "caller_ptr");
|
||||
|
||||
builder.build_store(ptr.into_pointer_value(), result);
|
||||
|
||||
// this is a void function, therefore return None
|
||||
builder.build_return(None);
|
||||
}
|
||||
Register => {
|
||||
// construct a normal return
|
||||
// values are passed to the caller via registers
|
||||
builder.build_return(Some(&result));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MUST set the personality at the very end;
|
||||
// doing it earlier can cause the personality to be ignored
|
||||
let personality_func = get_gxx_personality_v0(env);
|
||||
main_fn.set_personality_function(personality_func);
|
||||
|
||||
(main_fn_name, env.arena.alloc(main_fn))
|
||||
(main_fn_name, main_fn)
|
||||
}
|
||||
|
||||
fn get_inplace_from_layout(layout: &Layout<'_>) -> InPlace {
|
||||
|
@ -1875,9 +1610,19 @@ fn expose_function_to_host<'a, 'ctx, 'env>(
|
|||
env: &Env<'a, 'ctx, 'env>,
|
||||
roc_function: FunctionValue<'ctx>,
|
||||
) {
|
||||
let c_function_name: String = format!("{}_exposed", roc_function.get_name().to_str().unwrap());
|
||||
|
||||
expose_function_to_host_help(env, roc_function, &c_function_name);
|
||||
}
|
||||
|
||||
fn expose_function_to_host_help<'a, 'ctx, 'env>(
|
||||
env: &Env<'a, 'ctx, 'env>,
|
||||
roc_function: FunctionValue<'ctx>,
|
||||
c_function_name: &str,
|
||||
) -> FunctionValue<'ctx> {
|
||||
use inkwell::types::BasicType;
|
||||
|
||||
let roc_wrapper_function = make_exception_catching_wrapper(env, roc_function);
|
||||
let roc_wrapper_function = make_exception_catcher(env, roc_function);
|
||||
|
||||
let roc_function_type = roc_wrapper_function.get_type();
|
||||
|
||||
|
@ -1888,13 +1633,10 @@ fn expose_function_to_host<'a, 'ctx, 'env>(
|
|||
argument_types.push(output_type.into());
|
||||
|
||||
let c_function_type = env.context.void_type().fn_type(&argument_types, false);
|
||||
let c_function_name: String = format!("{}_exposed", roc_function.get_name().to_str().unwrap());
|
||||
|
||||
let c_function = env.module.add_function(
|
||||
c_function_name.as_str(),
|
||||
c_function_type,
|
||||
Some(Linkage::External),
|
||||
);
|
||||
let c_function =
|
||||
env.module
|
||||
.add_function(c_function_name, c_function_type, Some(Linkage::External));
|
||||
|
||||
// STEP 2: build the exposed function's body
|
||||
let builder = env.builder;
|
||||
|
@ -1942,6 +1684,8 @@ fn expose_function_to_host<'a, 'ctx, 'env>(
|
|||
|
||||
let size: BasicValueEnum = return_type.size_of().unwrap().into();
|
||||
builder.build_return(Some(&size));
|
||||
|
||||
c_function
|
||||
}
|
||||
|
||||
fn invoke_and_catch<'a, 'ctx, 'env, F, T>(
|
||||
|
@ -2095,9 +1839,19 @@ where
|
|||
result
|
||||
}
|
||||
|
||||
fn make_exception_catcher<'a, 'ctx, 'env>(
|
||||
env: &Env<'a, 'ctx, 'env>,
|
||||
roc_function: FunctionValue<'ctx>,
|
||||
) -> FunctionValue<'ctx> {
|
||||
let wrapper_function_name = format!("{}_catcher", roc_function.get_name().to_str().unwrap());
|
||||
|
||||
make_exception_catching_wrapper(env, roc_function, &wrapper_function_name)
|
||||
}
|
||||
|
||||
fn make_exception_catching_wrapper<'a, 'ctx, 'env>(
|
||||
env: &Env<'a, 'ctx, 'env>,
|
||||
roc_function: FunctionValue<'ctx>,
|
||||
wrapper_function_name: &str,
|
||||
) -> FunctionValue<'ctx> {
|
||||
// build the C calling convention wrapper
|
||||
|
||||
|
@ -2107,8 +1861,6 @@ fn make_exception_catching_wrapper<'a, 'ctx, 'env>(
|
|||
let roc_function_type = roc_function.get_type();
|
||||
let argument_types = roc_function_type.get_param_types();
|
||||
|
||||
let wrapper_function_name = format!("{}_catcher", roc_function.get_name().to_str().unwrap());
|
||||
|
||||
let wrapper_return_type = context.struct_type(
|
||||
&[
|
||||
context.i64_type().into(),
|
||||
|
@ -2328,70 +2080,6 @@ pub fn build_closure_caller<'a, 'ctx, 'env>(
|
|||
builder.build_return(Some(&size));
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn build_closure_caller_old<'a, 'ctx, 'env>(
|
||||
env: &'a Env<'a, 'ctx, 'env>,
|
||||
closure_function: FunctionValue<'ctx>,
|
||||
) {
|
||||
let context = env.context;
|
||||
let builder = env.builder;
|
||||
// asuming the closure has type `a, b, closure_data -> c`
|
||||
// change that into `a, b, *const closure_data, *mut output -> ()`
|
||||
|
||||
// a function `a, b, closure_data -> RocCallResult<c>`
|
||||
let wrapped_function = make_exception_catching_wrapper(env, closure_function);
|
||||
|
||||
let closure_function_type = closure_function.get_type();
|
||||
let wrapped_function_type = wrapped_function.get_type();
|
||||
|
||||
let mut arguments = closure_function_type.get_param_types();
|
||||
|
||||
// require that the closure data is passed by reference
|
||||
let closure_data_type = arguments.pop().unwrap();
|
||||
let closure_data_ptr_type = get_ptr_type(&closure_data_type, AddressSpace::Generic);
|
||||
arguments.push(closure_data_ptr_type.into());
|
||||
|
||||
// require that a pointer is passed in to write the result into
|
||||
let output_type = get_ptr_type(
|
||||
&wrapped_function_type.get_return_type().unwrap(),
|
||||
AddressSpace::Generic,
|
||||
);
|
||||
arguments.push(output_type.into());
|
||||
|
||||
let caller_function_type = env.context.void_type().fn_type(&arguments, false);
|
||||
let caller_function_name: String =
|
||||
format!("{}_caller", closure_function.get_name().to_str().unwrap());
|
||||
|
||||
let caller_function = env.module.add_function(
|
||||
caller_function_name.as_str(),
|
||||
caller_function_type,
|
||||
Some(Linkage::External),
|
||||
);
|
||||
|
||||
caller_function.set_call_conventions(C_CALL_CONV);
|
||||
|
||||
let entry = context.append_basic_block(caller_function, "entry");
|
||||
|
||||
builder.position_at_end(entry);
|
||||
|
||||
let mut parameters = caller_function.get_params();
|
||||
let output = parameters.pop().unwrap();
|
||||
let closure_data_ptr = parameters.pop().unwrap();
|
||||
|
||||
let closure_data =
|
||||
builder.build_load(closure_data_ptr.into_pointer_value(), "load_closure_data");
|
||||
parameters.push(closure_data);
|
||||
|
||||
let call = builder.build_call(wrapped_function, ¶meters, "call_wrapped_function");
|
||||
call.set_call_convention(FAST_CALL_CONV);
|
||||
|
||||
let result = call.try_as_basic_value().left().unwrap();
|
||||
|
||||
builder.build_store(output.into_pointer_value(), result);
|
||||
|
||||
builder.build_return(None);
|
||||
}
|
||||
|
||||
pub fn build_proc<'a, 'ctx, 'env>(
|
||||
env: &'a Env<'a, 'ctx, 'env>,
|
||||
layout_ids: &mut LayoutIds<'a>,
|
||||
|
|
|
@ -36,15 +36,20 @@ macro_rules! run_jit_function {
|
|||
($lib: expr, $main_fn_name: expr, $ty:ty, $transform:expr, $errors:expr) => {{
|
||||
use inkwell::context::Context;
|
||||
use roc_gen::run_roc::RocCallResult;
|
||||
use std::mem::MaybeUninit;
|
||||
|
||||
unsafe {
|
||||
let main: libloading::Symbol<unsafe extern "C" fn() -> RocCallResult<$ty>> = $lib
|
||||
.get($main_fn_name.as_bytes())
|
||||
.ok()
|
||||
.ok_or(format!("Unable to JIT compile `{}`", $main_fn_name))
|
||||
.expect("errored");
|
||||
let main: libloading::Symbol<unsafe extern "C" fn(*mut RocCallResult<$ty>) -> ()> =
|
||||
$lib.get($main_fn_name.as_bytes())
|
||||
.ok()
|
||||
.ok_or(format!("Unable to JIT compile `{}`", $main_fn_name))
|
||||
.expect("errored");
|
||||
|
||||
match main().into() {
|
||||
let mut result = MaybeUninit::uninit();
|
||||
|
||||
main(result.as_mut_ptr());
|
||||
|
||||
match result.assume_init().into() {
|
||||
Ok(success) => {
|
||||
// only if there are no exceptions thrown, check for errors
|
||||
assert_eq!(
|
||||
|
|
|
@ -1064,4 +1064,51 @@ mod gen_primitives {
|
|||
f64
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn return_wrapped_function_pointer() {
|
||||
assert_non_opt_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
app Test provides [ main ] imports []
|
||||
|
||||
Effect a : [ @Effect ({} -> a) ]
|
||||
|
||||
foo : Effect {}
|
||||
foo = @Effect \{} -> {}
|
||||
|
||||
main : Effect {}
|
||||
main = foo
|
||||
"#
|
||||
),
|
||||
1,
|
||||
i64,
|
||||
|_| 1
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn return_wrapped_closure() {
|
||||
assert_non_opt_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
app Test provides [ main ] imports []
|
||||
|
||||
Effect a : [ @Effect ({} -> a) ]
|
||||
|
||||
foo : Effect {}
|
||||
foo =
|
||||
x = 5
|
||||
|
||||
@Effect (\{} -> if x > 3 then {} else {})
|
||||
|
||||
main : Effect {}
|
||||
main = foo
|
||||
"#
|
||||
),
|
||||
1,
|
||||
i64,
|
||||
|_| 1
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1613,7 +1613,7 @@ fn update<'a>(
|
|||
// state.timings.insert(module_id, module_timing);
|
||||
|
||||
// display the mono IR of the module, for debug purposes
|
||||
if false {
|
||||
if roc_mono::ir::PRETTY_PRINT_IR_SYMBOLS {
|
||||
let procs_string = state
|
||||
.procedures
|
||||
.values()
|
||||
|
|
|
@ -14,6 +14,8 @@ use roc_types::subs::{Content, FlatType, Subs, Variable};
|
|||
use std::collections::HashMap;
|
||||
use ven_pretty::{BoxAllocator, DocAllocator, DocBuilder};
|
||||
|
||||
pub const PRETTY_PRINT_IR_SYMBOLS: bool = false;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub enum MonoProblem {
|
||||
PatternProblem(crate::exhaustive::Error),
|
||||
|
@ -902,8 +904,11 @@ where
|
|||
D::Doc: Clone,
|
||||
A: Clone,
|
||||
{
|
||||
alloc.text(format!("{}", symbol))
|
||||
// alloc.text(format!("{:?}", symbol))
|
||||
if PRETTY_PRINT_IR_SYMBOLS {
|
||||
alloc.text(format!("{:?}", symbol))
|
||||
} else {
|
||||
alloc.text(format!("{}", symbol))
|
||||
}
|
||||
}
|
||||
|
||||
fn join_point_to_doc<'b, D, A>(alloc: &'b D, symbol: JoinPointId) -> DocBuilder<'b, D, A>
|
||||
|
@ -1514,51 +1519,9 @@ fn specialize_external<'a>(
|
|||
pattern_symbols
|
||||
};
|
||||
|
||||
let (proc_args, opt_closure_layout, ret_layout) =
|
||||
let specialized =
|
||||
build_specialized_proc_from_var(env, layout_cache, proc_name, pattern_symbols, fn_var)?;
|
||||
|
||||
let mut specialized_body = from_can(env, fn_var, body, procs, layout_cache);
|
||||
// unpack the closure symbols, if any
|
||||
if let CapturedSymbols::Captured(captured) = captured_symbols {
|
||||
let mut layouts = Vec::with_capacity_in(captured.len(), env.arena);
|
||||
|
||||
for (_, variable) in captured.iter() {
|
||||
let layout = layout_cache.from_var(env.arena, *variable, env.subs)?;
|
||||
layouts.push(layout);
|
||||
}
|
||||
|
||||
let field_layouts = layouts.into_bump_slice();
|
||||
|
||||
let wrapped = match &opt_closure_layout {
|
||||
Some(x) => x.get_wrapped(),
|
||||
None => unreachable!("symbols are captured, so this must be a closure"),
|
||||
};
|
||||
|
||||
for (index, (symbol, variable)) in captured.iter().enumerate() {
|
||||
// layout is cached anyway, re-using the one found above leads to
|
||||
// issues (combining by-ref and by-move in pattern match
|
||||
let layout = layout_cache.from_var(env.arena, *variable, env.subs)?;
|
||||
|
||||
// if the symbol has a layout that is dropped from data structures (e.g. `{}`)
|
||||
// then regenerate the symbol here. The value may not be present in the closure
|
||||
// data struct
|
||||
let expr = {
|
||||
if layout.is_dropped_because_empty() {
|
||||
Expr::Struct(&[])
|
||||
} else {
|
||||
Expr::AccessAtIndex {
|
||||
index: index as _,
|
||||
field_layouts,
|
||||
structure: Symbol::ARG_CLOSURE,
|
||||
wrapped,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
specialized_body = Stmt::Let(*symbol, expr, layout, env.arena.alloc(specialized_body));
|
||||
}
|
||||
}
|
||||
|
||||
// determine the layout of aliases/rigids exposed to the host
|
||||
let host_exposed_layouts = if host_exposed_variables.is_empty() {
|
||||
HostExposedLayouts::NotHostExposed
|
||||
|
@ -1578,39 +1541,141 @@ fn specialize_external<'a>(
|
|||
}
|
||||
};
|
||||
|
||||
// reset subs, so we don't get type errors when specializing for a different signature
|
||||
layout_cache.rollback_to(cache_snapshot);
|
||||
env.subs.rollback_to(snapshot);
|
||||
|
||||
let recursivity = if is_self_recursive {
|
||||
SelfRecursive::SelfRecursive(JoinPointId(env.unique_symbol()))
|
||||
} else {
|
||||
SelfRecursive::NotSelfRecursive
|
||||
};
|
||||
|
||||
let closure_data_layout = match opt_closure_layout {
|
||||
Some(closure_layout) => Some(closure_layout.as_named_layout(proc_name)),
|
||||
None => None,
|
||||
};
|
||||
let mut specialized_body = from_can(env, fn_var, body, procs, layout_cache);
|
||||
|
||||
let proc = Proc {
|
||||
name: proc_name,
|
||||
args: proc_args,
|
||||
body: specialized_body,
|
||||
closure_data_layout,
|
||||
ret_layout,
|
||||
is_self_recursive: recursivity,
|
||||
host_exposed_layouts,
|
||||
};
|
||||
match specialized {
|
||||
SpecializedLayout::FunctionPointerBody {
|
||||
arguments,
|
||||
ret_layout,
|
||||
closure: opt_closure_layout,
|
||||
} => {
|
||||
// this is a function body like
|
||||
//
|
||||
// foo = Num.add
|
||||
//
|
||||
// we need to expand this to
|
||||
//
|
||||
// foo = \x,y -> Num.add x y
|
||||
|
||||
Ok(proc)
|
||||
// reset subs, so we don't get type errors when specializing for a different signature
|
||||
layout_cache.rollback_to(cache_snapshot);
|
||||
env.subs.rollback_to(snapshot);
|
||||
|
||||
let closure_data_layout = match opt_closure_layout {
|
||||
Some(closure_layout) => Some(closure_layout.as_named_layout(proc_name)),
|
||||
None => None,
|
||||
};
|
||||
|
||||
// I'm not sure how to handle the closure case, does it ever occur?
|
||||
debug_assert_eq!(closure_data_layout, None);
|
||||
debug_assert!(matches!(captured_symbols, CapturedSymbols::None));
|
||||
|
||||
// this will be a thunk returning a function, so its ret_layout must be a function!
|
||||
let full_layout = Layout::FunctionPointer(arguments, env.arena.alloc(ret_layout));
|
||||
|
||||
let proc = Proc {
|
||||
name: proc_name,
|
||||
args: &[],
|
||||
body: specialized_body,
|
||||
closure_data_layout,
|
||||
ret_layout: full_layout,
|
||||
is_self_recursive: recursivity,
|
||||
host_exposed_layouts,
|
||||
};
|
||||
|
||||
Ok(proc)
|
||||
}
|
||||
SpecializedLayout::FunctionBody {
|
||||
arguments: proc_args,
|
||||
closure: opt_closure_layout,
|
||||
ret_layout,
|
||||
} => {
|
||||
// unpack the closure symbols, if any
|
||||
if let CapturedSymbols::Captured(captured) = captured_symbols {
|
||||
let mut layouts = Vec::with_capacity_in(captured.len(), env.arena);
|
||||
|
||||
for (_, variable) in captured.iter() {
|
||||
let layout = layout_cache.from_var(env.arena, *variable, env.subs)?;
|
||||
layouts.push(layout);
|
||||
}
|
||||
|
||||
let field_layouts = layouts.into_bump_slice();
|
||||
|
||||
let wrapped = match &opt_closure_layout {
|
||||
Some(x) => x.get_wrapped(),
|
||||
None => unreachable!("symbols are captured, so this must be a closure"),
|
||||
};
|
||||
|
||||
for (index, (symbol, variable)) in captured.iter().enumerate() {
|
||||
// layout is cached anyway, re-using the one found above leads to
|
||||
// issues (combining by-ref and by-move in pattern match
|
||||
let layout = layout_cache.from_var(env.arena, *variable, env.subs)?;
|
||||
|
||||
// if the symbol has a layout that is dropped from data structures (e.g. `{}`)
|
||||
// then regenerate the symbol here. The value may not be present in the closure
|
||||
// data struct
|
||||
let expr = {
|
||||
if layout.is_dropped_because_empty() {
|
||||
Expr::Struct(&[])
|
||||
} else {
|
||||
Expr::AccessAtIndex {
|
||||
index: index as _,
|
||||
field_layouts,
|
||||
structure: Symbol::ARG_CLOSURE,
|
||||
wrapped,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
specialized_body =
|
||||
Stmt::Let(*symbol, expr, layout, env.arena.alloc(specialized_body));
|
||||
}
|
||||
}
|
||||
|
||||
// reset subs, so we don't get type errors when specializing for a different signature
|
||||
layout_cache.rollback_to(cache_snapshot);
|
||||
env.subs.rollback_to(snapshot);
|
||||
|
||||
let closure_data_layout = match opt_closure_layout {
|
||||
Some(closure_layout) => Some(closure_layout.as_named_layout(proc_name)),
|
||||
None => None,
|
||||
};
|
||||
|
||||
let proc = Proc {
|
||||
name: proc_name,
|
||||
args: proc_args,
|
||||
body: specialized_body,
|
||||
closure_data_layout,
|
||||
ret_layout,
|
||||
is_self_recursive: recursivity,
|
||||
host_exposed_layouts,
|
||||
};
|
||||
|
||||
Ok(proc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type SpecializedLayout<'a> = (
|
||||
&'a [(Layout<'a>, Symbol)],
|
||||
Option<ClosureLayout<'a>>,
|
||||
Layout<'a>,
|
||||
);
|
||||
enum SpecializedLayout<'a> {
|
||||
/// A body like `foo = \a,b,c -> ...`
|
||||
FunctionBody {
|
||||
arguments: &'a [(Layout<'a>, Symbol)],
|
||||
closure: Option<ClosureLayout<'a>>,
|
||||
ret_layout: Layout<'a>,
|
||||
},
|
||||
/// A body like `foo = Num.add`
|
||||
FunctionPointerBody {
|
||||
arguments: &'a [Layout<'a>],
|
||||
closure: Option<ClosureLayout<'a>>,
|
||||
ret_layout: Layout<'a>,
|
||||
},
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn build_specialized_proc_from_var<'a>(
|
||||
|
@ -1739,6 +1804,7 @@ fn build_specialized_proc<'a>(
|
|||
let mut proc_args = Vec::with_capacity_in(pattern_layouts.len(), arena);
|
||||
|
||||
let pattern_layouts_len = pattern_layouts.len();
|
||||
let pattern_layouts_slice = pattern_layouts.clone().into_bump_slice();
|
||||
|
||||
for (arg_layout, arg_name) in pattern_layouts.into_iter().zip(pattern_symbols.iter()) {
|
||||
proc_args.push((arg_layout, *arg_name));
|
||||
|
@ -1761,6 +1827,7 @@ fn build_specialized_proc<'a>(
|
|||
// f_closure = { ptr: f, closure: x }
|
||||
//
|
||||
// then
|
||||
use SpecializedLayout::*;
|
||||
match opt_closure_layout {
|
||||
Some(layout) if pattern_symbols.last() == Some(&Symbol::ARG_CLOSURE) => {
|
||||
// here we define the lifted (now top-level) f function. Its final argument is `Symbol::ARG_CLOSURE`,
|
||||
|
@ -1776,7 +1843,11 @@ fn build_specialized_proc<'a>(
|
|||
|
||||
let proc_args = proc_args.into_bump_slice();
|
||||
|
||||
Ok((proc_args, Some(layout), ret_layout))
|
||||
Ok(FunctionBody {
|
||||
arguments: proc_args,
|
||||
closure: Some(layout),
|
||||
ret_layout,
|
||||
})
|
||||
}
|
||||
Some(layout) => {
|
||||
// else if there is a closure layout, we're building the `f_closure` value
|
||||
|
@ -1791,7 +1862,11 @@ fn build_specialized_proc<'a>(
|
|||
let closure_layout =
|
||||
Layout::Struct(arena.alloc([function_ptr_layout, closure_data_layout]));
|
||||
|
||||
Ok((&[], None, closure_layout))
|
||||
Ok(FunctionBody {
|
||||
arguments: &[],
|
||||
closure: None,
|
||||
ret_layout: closure_layout,
|
||||
})
|
||||
}
|
||||
None => {
|
||||
// else we're making a normal function, no closure problems to worry about
|
||||
|
@ -1805,16 +1880,32 @@ fn build_specialized_proc<'a>(
|
|||
Ordering::Equal => {
|
||||
let proc_args = proc_args.into_bump_slice();
|
||||
|
||||
Ok((proc_args, None, ret_layout))
|
||||
Ok(FunctionBody {
|
||||
arguments: proc_args,
|
||||
closure: None,
|
||||
ret_layout,
|
||||
})
|
||||
}
|
||||
Ordering::Greater => {
|
||||
// so far, the problem when hitting this branch was always somewhere else
|
||||
// I think this branch should not be reachable in a bugfree compiler
|
||||
panic!("more arguments (according to the layout) than argument symbols")
|
||||
}
|
||||
Ordering::Less => {
|
||||
panic!("more argument symbols than arguments (according to the layout)")
|
||||
if pattern_symbols.is_empty() {
|
||||
Ok(FunctionPointerBody {
|
||||
arguments: pattern_layouts_slice,
|
||||
closure: None,
|
||||
ret_layout,
|
||||
})
|
||||
} else {
|
||||
// so far, the problem when hitting this branch was always somewhere else
|
||||
// I think this branch should not be reachable in a bugfree compiler
|
||||
panic!(
|
||||
"more arguments (according to the layout) than argument symbols for {:?}",
|
||||
proc_name
|
||||
)
|
||||
}
|
||||
}
|
||||
Ordering::Less => panic!(
|
||||
"more argument symbols than arguments (according to the layout) for {:?}",
|
||||
proc_name
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2213,8 +2304,22 @@ pub fn with_hole<'a>(
|
|||
} else if symbol.module_id() != env.home && symbol.module_id() != ModuleId::ATTR {
|
||||
match layout_cache.from_var(env.arena, variable, env.subs) {
|
||||
Err(e) => panic!("invalid layout {:?}", e),
|
||||
Ok(Layout::FunctionPointer(_, _)) => {
|
||||
Ok(layout @ Layout::FunctionPointer(_, _)) => {
|
||||
add_needed_external(procs, env, variable, symbol);
|
||||
|
||||
match hole {
|
||||
Stmt::Jump(_, _) => todo!("not sure what to do in this case yet"),
|
||||
_ => {
|
||||
let expr = Expr::FunctionPointer(symbol, layout.clone());
|
||||
let new_symbol = env.unique_symbol();
|
||||
return Stmt::Let(
|
||||
new_symbol,
|
||||
expr,
|
||||
layout,
|
||||
env.arena.alloc(Stmt::Ret(new_symbol)),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(_) => {
|
||||
// this is a 0-arity thunk
|
||||
|
@ -4678,6 +4783,12 @@ fn call_by_name<'a>(
|
|||
.specialized
|
||||
.contains_key(&(proc_name, full_layout.clone()))
|
||||
{
|
||||
debug_assert_eq!(
|
||||
arg_layouts.len(),
|
||||
field_symbols.len(),
|
||||
"see call_by_name for background (scroll down a bit)"
|
||||
);
|
||||
|
||||
let call = Expr::FunctionCall {
|
||||
call_type: CallType::ByName(proc_name),
|
||||
ret_layout: ret_layout.clone(),
|
||||
|
@ -4721,6 +4832,11 @@ fn call_by_name<'a>(
|
|||
);
|
||||
}
|
||||
|
||||
debug_assert_eq!(
|
||||
arg_layouts.len(),
|
||||
field_symbols.len(),
|
||||
"see call_by_name for background (scroll down a bit)"
|
||||
);
|
||||
let call = Expr::FunctionCall {
|
||||
call_type: CallType::ByName(proc_name),
|
||||
ret_layout: ret_layout.clone(),
|
||||
|
@ -4762,30 +4878,102 @@ fn call_by_name<'a>(
|
|||
let function_layout =
|
||||
FunctionLayouts::from_layout(env.arena, layout);
|
||||
|
||||
procs.specialized.remove(&(proc_name, full_layout));
|
||||
procs.specialized.remove(&(proc_name, full_layout.clone()));
|
||||
|
||||
procs.specialized.insert(
|
||||
(proc_name, function_layout.full.clone()),
|
||||
Done(proc),
|
||||
);
|
||||
|
||||
let call = Expr::FunctionCall {
|
||||
call_type: CallType::ByName(proc_name),
|
||||
ret_layout: function_layout.result.clone(),
|
||||
full_layout: function_layout.full,
|
||||
arg_layouts: function_layout.arguments,
|
||||
args: field_symbols,
|
||||
};
|
||||
if field_symbols.is_empty() {
|
||||
debug_assert!(loc_args.is_empty());
|
||||
|
||||
let iter = loc_args
|
||||
.into_iter()
|
||||
.rev()
|
||||
.zip(field_symbols.iter().rev());
|
||||
// This happens when we return a function, e.g.
|
||||
//
|
||||
// foo = Num.add
|
||||
//
|
||||
// Even though the layout (and type) are functions,
|
||||
// there are no arguments. This confuses our IR,
|
||||
// and we have to fix it here.
|
||||
match full_layout {
|
||||
Layout::Closure(_, closure_layout, _) => {
|
||||
let call = Expr::FunctionCall {
|
||||
call_type: CallType::ByName(proc_name),
|
||||
ret_layout: function_layout.result.clone(),
|
||||
full_layout: function_layout.full.clone(),
|
||||
arg_layouts: function_layout.arguments,
|
||||
args: field_symbols,
|
||||
};
|
||||
|
||||
let result =
|
||||
Stmt::Let(assigned, call, function_layout.result, hole);
|
||||
// in the case of a closure specifically, we
|
||||
// have to create a custom layout, to make sure
|
||||
// the closure data is part of the layout
|
||||
let closure_struct_layout = Layout::Struct(
|
||||
env.arena.alloc([
|
||||
function_layout.full,
|
||||
closure_layout
|
||||
.as_block_of_memory_layout(),
|
||||
]),
|
||||
);
|
||||
|
||||
assign_to_symbols(env, procs, layout_cache, iter, result)
|
||||
Stmt::Let(
|
||||
assigned,
|
||||
call,
|
||||
closure_struct_layout,
|
||||
hole,
|
||||
)
|
||||
}
|
||||
_ => {
|
||||
let call = Expr::FunctionCall {
|
||||
call_type: CallType::ByName(proc_name),
|
||||
ret_layout: function_layout.result.clone(),
|
||||
full_layout: function_layout.full.clone(),
|
||||
arg_layouts: function_layout.arguments,
|
||||
args: field_symbols,
|
||||
};
|
||||
|
||||
Stmt::Let(
|
||||
assigned,
|
||||
call,
|
||||
function_layout.full,
|
||||
hole,
|
||||
)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
debug_assert_eq!(
|
||||
function_layout.arguments.len(),
|
||||
field_symbols.len(),
|
||||
"scroll up a bit for background"
|
||||
);
|
||||
let call = Expr::FunctionCall {
|
||||
call_type: CallType::ByName(proc_name),
|
||||
ret_layout: function_layout.result.clone(),
|
||||
full_layout: function_layout.full,
|
||||
arg_layouts: function_layout.arguments,
|
||||
args: field_symbols,
|
||||
};
|
||||
|
||||
let iter = loc_args
|
||||
.into_iter()
|
||||
.rev()
|
||||
.zip(field_symbols.iter().rev());
|
||||
|
||||
let result = Stmt::Let(
|
||||
assigned,
|
||||
call,
|
||||
function_layout.result,
|
||||
hole,
|
||||
);
|
||||
|
||||
assign_to_symbols(
|
||||
env,
|
||||
procs,
|
||||
layout_cache,
|
||||
iter,
|
||||
result,
|
||||
)
|
||||
}
|
||||
}
|
||||
Err(error) => {
|
||||
let error_msg = env.arena.alloc(format!(
|
||||
|
@ -4804,6 +4992,12 @@ fn call_by_name<'a>(
|
|||
None if assigned.module_id() != proc_name.module_id() => {
|
||||
add_needed_external(procs, env, original_fn_var, proc_name);
|
||||
|
||||
debug_assert_eq!(
|
||||
arg_layouts.len(),
|
||||
field_symbols.len(),
|
||||
"scroll up a bit for background"
|
||||
);
|
||||
|
||||
let call = Expr::FunctionCall {
|
||||
call_type: CallType::ByName(proc_name),
|
||||
ret_layout: ret_layout.clone(),
|
||||
|
|
|
@ -1422,27 +1422,27 @@ mod test_mono {
|
|||
indoc!(
|
||||
r#"
|
||||
procedure Num.15 (#Attr.2, #Attr.3):
|
||||
let Test.13 = lowlevel NumSub #Attr.2 #Attr.3;
|
||||
ret Test.13;
|
||||
let Test.14 = lowlevel NumSub #Attr.2 #Attr.3;
|
||||
ret Test.14;
|
||||
|
||||
procedure Num.16 (#Attr.2, #Attr.3):
|
||||
let Test.11 = lowlevel NumMul #Attr.2 #Attr.3;
|
||||
ret Test.11;
|
||||
let Test.12 = lowlevel NumMul #Attr.2 #Attr.3;
|
||||
ret Test.12;
|
||||
|
||||
procedure Test.1 (Test.2, Test.3):
|
||||
jump Test.18 Test.2 Test.3;
|
||||
joinpoint Test.18 Test.2 Test.3:
|
||||
let Test.15 = true;
|
||||
let Test.16 = 0i64;
|
||||
let Test.17 = lowlevel Eq Test.16 Test.2;
|
||||
let Test.14 = lowlevel And Test.17 Test.15;
|
||||
if Test.14 then
|
||||
jump Test.7 Test.2 Test.3;
|
||||
joinpoint Test.7 Test.2 Test.3:
|
||||
let Test.16 = true;
|
||||
let Test.17 = 0i64;
|
||||
let Test.18 = lowlevel Eq Test.17 Test.2;
|
||||
let Test.15 = lowlevel And Test.18 Test.16;
|
||||
if Test.15 then
|
||||
ret Test.3;
|
||||
else
|
||||
let Test.12 = 1i64;
|
||||
let Test.9 = CallByName Num.15 Test.2 Test.12;
|
||||
let Test.10 = CallByName Num.16 Test.2 Test.3;
|
||||
jump Test.18 Test.9 Test.10;
|
||||
let Test.13 = 1i64;
|
||||
let Test.10 = CallByName Num.15 Test.2 Test.13;
|
||||
let Test.11 = CallByName Num.16 Test.2 Test.3;
|
||||
jump Test.7 Test.10 Test.11;
|
||||
|
||||
procedure Test.0 ():
|
||||
let Test.5 = 10i64;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue