Set.walk closure trouble

This commit is contained in:
Folkert 2021-02-15 02:04:04 +01:00
parent 35e1e94a94
commit 9527434be8
3 changed files with 111 additions and 22 deletions

View file

@ -1850,6 +1850,7 @@ fn invoke_roc_function<'a, 'ctx, 'env>(
layout: Layout<'a>,
function_value: Either<FunctionValue<'ctx>, PointerValue<'ctx>>,
arguments: &[Symbol],
closure_argument: Option<BasicValueEnum<'ctx>>,
pass: &'a roc_mono::ir::Stmt<'a>,
fail: &'a roc_mono::ir::Stmt<'a>,
) -> BasicValueEnum<'ctx> {
@ -1860,6 +1861,7 @@ fn invoke_roc_function<'a, 'ctx, 'env>(
for arg in arguments.iter() {
arg_vals.push(load_symbol(scope, arg));
}
arg_vals.extend(closure_argument);
let pass_block = context.append_basic_block(parent, "invoke_pass");
let fail_block = context.append_basic_block(parent, "invoke_fail");
@ -2019,6 +2021,7 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>(
layout.clone(),
function_value.into(),
call.arguments,
None,
pass,
fail,
)
@ -2026,15 +2029,35 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>(
CallType::ByPointer { name, .. } => {
let sub_expr = load_symbol(scope, &name);
let function_ptr = match sub_expr {
BasicValueEnum::PointerValue(ptr) => ptr,
non_ptr => {
panic!(
"Tried to call by pointer, but encountered a non-pointer: {:?}",
non_ptr
);
match sub_expr {
BasicValueEnum::PointerValue(function_ptr) => {
// basic call by pointer
invoke_roc_function(
env,
layout_ids,
scope,
parent,
*symbol,
layout.clone(),
function_ptr.into(),
call.arguments,
None,
pass,
fail,
)
}
};
BasicValueEnum::StructValue(ptr_and_data) => {
// this is a closure
let builder = env.builder;
let function_ptr = builder
.build_extract_value(ptr_and_data, 0, "function_ptr")
.unwrap()
.into_pointer_value();
let closure_data = builder
.build_extract_value(ptr_and_data, 1, "closure_data")
.unwrap();
invoke_roc_function(
env,
@ -2045,10 +2068,19 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>(
layout.clone(),
function_ptr.into(),
call.arguments,
Some(closure_data),
pass,
fail,
)
}
non_ptr => {
panic!(
"Tried to call by pointer, but encountered a non-pointer: {:?}",
non_ptr
);
}
}
}
CallType::Foreign {
ref foreign_symbol,
ref ret_layout,

View file

@ -921,6 +921,63 @@ pub fn list_walk<'a, 'ctx, 'env>(
);
}
(
BasicValueEnum::StructValue(ptr_and_data),
Layout::Closure(_, closure_layout, ret_elem_layout),
) => {
let builder = env.builder;
let func_ptr = builder
.build_extract_value(ptr_and_data, 0, "function_ptr")
.unwrap()
.into_pointer_value();
let closure_data = builder
.build_extract_value(ptr_and_data, 1, "closure_data")
.unwrap();
let elem_layout = match list_layout {
Layout::Builtin(Builtin::List(_, layout)) => layout,
_ => unreachable!("can only fold over a list"),
};
let elem_type = basic_type_from_layout(env.arena, ctx, elem_layout, env.ptr_bytes);
let elem_ptr_type = get_ptr_type(&elem_type, AddressSpace::Generic);
let list_ptr = load_list_ptr(builder, list_wrapper, elem_ptr_type);
let walk_right_loop = |_, elem: BasicValueEnum<'ctx>| {
// load current accumulator
let current = builder.build_load(accum_alloca, "retrieve_accum");
let call_site_value = builder.build_call(
func_ptr,
&[elem, current, closure_data],
"#walk_right_func",
);
// set the calling convention explicitly for this call
call_site_value.set_call_convention(crate::llvm::build::FAST_CALL_CONV);
let new_current = call_site_value
.try_as_basic_value()
.left()
.unwrap_or_else(|| panic!("LLVM error: Invalid call by pointer."));
builder.build_store(accum_alloca, new_current);
};
incrementing_elem_loop(
builder,
ctx,
parent,
list_ptr,
len,
"#index",
walk_right_loop,
);
}
_ => {
unreachable!(
"Invalid function basic value enum or layout for List.keepIf : {:?}",

View file

@ -5933,12 +5933,6 @@ fn call_by_name<'a>(
None if assigned.module_id() != proc_name.module_id() => {
add_needed_external(procs, env, original_fn_var, proc_name);
debug_assert_eq!(
arg_layouts.len(),
field_symbols.len(),
"scroll up a bit for background"
);
let call = if proc_name.module_id() == ModuleId::ATTR {
// the callable is one of the ATTR::ARG_n symbols
// we must call those by-pointer
@ -5952,6 +5946,12 @@ fn call_by_name<'a>(
arguments: field_symbols,
}
} else {
debug_assert_eq!(
arg_layouts.len(),
field_symbols.len(),
"scroll up a bit for background {:?}",
proc_name
);
self::Call {
call_type: CallType::ByName {
name: proc_name,