Call erased functions

This commit is contained in:
Ayaz Hafiz 2023-06-25 18:10:51 -05:00
parent 558d7459b4
commit 510a421748
No known key found for this signature in database
GPG key ID: 0E2A37416A25EF58
10 changed files with 350 additions and 23 deletions

View file

@ -9,6 +9,7 @@ use bumpalo::collections::Vec;
use bumpalo::Bump;
use roc_collections::all::{MutMap, MutSet};
use roc_collections::ReferenceMatrix;
use roc_error_macros::todo_lambda_erasure;
use roc_module::low_level::LowLevel;
use roc_module::symbol::Symbol;
@ -560,6 +561,10 @@ impl<'a> BorrowInfState<'a> {
self.own_args_using_params(arguments, ps);
}
ByPointer { .. } => {
todo_lambda_erasure!()
}
LowLevel { op, .. } => {
debug_assert!(!op.is_higher_order());
@ -1056,6 +1061,9 @@ fn call_info_call<'a>(call: &crate::ir::Call<'a>, info: &mut CallInfo<'a>) {
ByName { name, .. } => {
info.keys.push(name.name());
}
ByPointer { .. } => {
todo_lambda_erasure!()
}
Foreign { .. } => {}
LowLevel { .. } => {}
HigherOrder(HigherOrderLowLevel {

View file

@ -10,8 +10,8 @@ use crate::{
ModifyRc, Param, Proc, ProcLayout, Stmt,
},
layout::{
Builtin, InLayout, Layout, LayoutInterner, LayoutRepr, STLayoutInterner, TagIdIntType,
UnionLayout,
Builtin, FunctionPointer, InLayout, Layout, LayoutInterner, LayoutRepr, STLayoutInterner,
TagIdIntType, UnionLayout,
},
};
@ -642,6 +642,23 @@ impl<'a, 'r> Ctx<'a, 'r> {
}
Some(*ret_layout)
}
CallType::ByPointer {
pointer,
ret_layout,
arg_layouts,
} => {
let expected_layout =
self.interner
.insert_direct_no_semantic(LayoutRepr::FunctionPointer(FunctionPointer {
args: arg_layouts,
ret: *ret_layout,
}));
self.check_sym_layout(*pointer, expected_layout, UseKind::SwitchCond);
for (arg, wanted_layout) in arguments.iter().zip(arg_layouts.iter()) {
self.check_sym_layout(*arg, *wanted_layout, UseKind::CallArg);
}
Some(*ret_layout)
}
CallType::HigherOrder(HigherOrderLowLevel {
op: _,
closure_env_layout: _,

View file

@ -10,7 +10,7 @@ use std::{collections::HashMap, hash::BuildHasherDefault};
use bumpalo::collections::{CollectIn, Vec};
use bumpalo::Bump;
use roc_collections::{all::WyHash, MutMap, MutSet};
use roc_error_macros::internal_error;
use roc_error_macros::{internal_error, todo_lambda_erasure};
use roc_module::low_level::LowLevel;
use roc_module::{low_level::LowLevelWrapperType, symbol::Symbol};
@ -942,6 +942,9 @@ fn insert_refcount_operations_binding<'a>(
inc_owned!(arguments.iter().copied(), new_let)
}
CallType::ByPointer { .. } => {
todo_lambda_erasure!()
}
CallType::Foreign { .. } => {
// Foreign functions should be responsible for their own memory management.
// But previously they were assumed to be called with borrowed parameters, so we do the same now.

View file

@ -5452,25 +5452,21 @@ pub fn with_hole<'a>(
env.arena.alloc(result),
);
}
RawFunctionLayout::ErasedFunction(..) => {
// What we want here is
// f = compile(loc_expr)
// joinpoint join result:
// <hole>
// if (f.value) {
// f = cast(f, (..params, void*) -> ret);
// result = f ..args
// jump join result
// } else {
// f = cast(f, (..params) -> ret);
// result = f ..args
// jump join result
// }
todo_lambda_erasure!(
"{:?} :: {:?}",
RawFunctionLayout::ErasedFunction(arg_layouts, ret_layout) => {
let hole_layout =
layout_cache.from_var(env.arena, fn_var, env.subs).unwrap();
result = erased::call_erased_function(
env,
layout_cache,
procs,
loc_expr.value,
full_layout
)
fn_var,
(arg_layouts, ret_layout),
arg_symbols,
assigned,
hole,
hole_layout,
);
}
RawFunctionLayout::ZeroArgumentThunk(_) => {
unreachable!(

View file

@ -0,0 +1,286 @@
use bumpalo::{collections::Vec as AVec, Bump};
use roc_module::{low_level::LowLevel, symbol::Symbol};
use roc_types::subs::Variable;
use crate::{
borrow::Ownership,
layout::{FunctionPointer, InLayout, Layout, LayoutCache, LayoutRepr},
};
use super::{
with_hole, BranchInfo, Call, CallType, Env, Expr, JoinPointId, Param, Procs, Stmt, UpdateModeId,
};
const ERASED_FUNCTION_FIELD_LAYOUTS: &[InLayout] =
&[Layout::OPAQUE_PTR, Layout::OPAQUE_PTR, Layout::OPAQUE_PTR];
/// The layout of an erased function is
///
/// ```
/// {
/// value: void*,
/// callee: void*,
/// refcounter: void*,
/// }
/// ```
fn erased_function_layout<'a>(layout_cache: &mut LayoutCache<'a>) -> InLayout<'a> {
layout_cache.put_in_direct_no_semantic(LayoutRepr::Struct(ERASED_FUNCTION_FIELD_LAYOUTS))
}
#[repr(u8)]
enum ErasedFunctionIndex {
Value = 0,
Callee = 1,
RefCounter = 2,
}
fn index_erased_function<'a>(
arena: &'a Bump,
assign_to: Symbol,
erased_function: Symbol,
index: ErasedFunctionIndex,
) -> impl FnOnce(Stmt<'a>) -> Stmt<'a> {
move |rest| {
Stmt::Let(
assign_to,
Expr::StructAtIndex {
index: index as _,
structure: erased_function,
field_layouts: ERASED_FUNCTION_FIELD_LAYOUTS,
},
Layout::OPAQUE_PTR,
arena.alloc(rest),
)
}
}
fn cast_erased_callee<'a>(
arena: &'a Bump,
assign_to: Symbol,
erased_function: Symbol,
fn_ptr_layout: InLayout<'a>,
) -> impl FnOnce(Stmt<'a>) -> Stmt<'a> {
move |rest| {
Stmt::Let(
assign_to,
Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::PtrCast,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: arena.alloc([erased_function]),
}),
fn_ptr_layout,
arena.alloc(rest),
)
}
}
fn call_callee<'a>(
arena: &'a Bump,
result_symbol: Symbol,
result: InLayout<'a>,
fn_ptr_symbol: Symbol,
fn_arg_layouts: &'a [InLayout<'a>],
fn_arguments: &'a [Symbol],
) -> impl FnOnce(Stmt<'a>) -> Stmt<'a> {
move |rest| {
Stmt::Let(
result_symbol,
Expr::Call(Call {
call_type: CallType::ByPointer {
pointer: fn_ptr_symbol,
ret_layout: result,
arg_layouts: fn_arg_layouts,
},
arguments: fn_arguments,
}),
result,
arena.alloc(rest),
)
}
}
fn is_null<'a>(
env: &mut Env<'a, '_>,
arena: &'a Bump,
assign_to: Symbol,
ptr_symbol: Symbol,
) -> impl FnOnce(Stmt<'a>) -> Stmt<'a> {
let null_symbol = env.unique_symbol();
move |rest| {
Stmt::Let(
null_symbol,
Expr::NullPointer,
Layout::OPAQUE_PTR,
arena.alloc(Stmt::Let(
assign_to,
Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::Eq,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: arena.alloc([ptr_symbol, null_symbol]),
}),
Layout::BOOL,
arena.alloc(rest),
)),
)
}
}
/// Given
///
/// ```
/// Call(f, args)
/// ```
///
/// We generate
///
/// ```
/// f = compile(f)
/// joinpoint join result:
/// <hole>
/// if (f.value) {
/// f = cast(f.callee, (..params, void*) -> ret);
/// result = f ..args f.value
/// jump join result
/// } else {
/// f = cast(f.callee, (..params) -> ret);
/// result = f ..args
/// jump join result
/// }
/// ```
pub fn call_erased_function<'a>(
env: &mut Env<'a, '_>,
layout_cache: &mut LayoutCache<'a>,
procs: &mut Procs<'a>,
function_expr: roc_can::expr::Expr,
function_var: Variable,
function_signature: (&'a [InLayout<'a>], InLayout<'a>),
function_argument_symbols: &'a [Symbol],
call_result_symbol: Symbol,
hole: &'a Stmt<'a>,
hole_layout: InLayout<'a>,
) -> Stmt<'a> {
let arena = env.arena;
let (f_args, f_ret) = function_signature;
let f = env.unique_symbol();
let join_point_id = JoinPointId(env.unique_symbol());
// f_value = f.value
let f_value = env.unique_symbol();
let let_f_value = index_erased_function(arena, f_value, f, ErasedFunctionIndex::Value);
// f_callee = f.callee
let f_callee = env.unique_symbol();
let let_f_callee = index_erased_function(arena, f_callee, f, ErasedFunctionIndex::Callee);
let mut build_closure_data_branch = |env: &mut Env, pass_closure| {
// f_callee = cast(f_callee, (..params) -> ret);
// result = f_callee ..args
// jump join result
let (f_args, function_argument_symbols) = if pass_closure {
// f_args = ...args, f.value
// function_argument_symbols = ...args, f.value
let f_args = {
let mut args = AVec::with_capacity_in(f_args.len() + 1, arena);
args.extend(f_args.iter().chain(&[Layout::OPAQUE_PTR]).copied());
args.into_bump_slice()
};
let function_argument_symbols = {
let mut args = AVec::with_capacity_in(function_argument_symbols.len() + 1, arena);
args.extend(function_argument_symbols.iter().chain(&[f_value]));
args.into_bump_slice()
};
(f_args, function_argument_symbols)
} else {
(f_args, function_argument_symbols)
};
let fn_ptr_layout =
layout_cache.put_in_direct_no_semantic(LayoutRepr::FunctionPointer(FunctionPointer {
args: f_args,
ret: f_ret,
}));
let f_callee_cast = env.unique_symbol();
let let_f_callee_cast = cast_erased_callee(arena, f_callee_cast, f_callee, fn_ptr_layout);
let result = env.unique_symbol();
let let_result = call_callee(
arena,
result,
f_ret,
f_callee_cast,
f_args,
function_argument_symbols,
);
let_f_callee_cast(
//
let_result(
//
Stmt::Jump(join_point_id, arena.alloc([result])),
),
)
};
let value_is_null = env.unique_symbol();
let let_value_is_null = is_null(env, arena, value_is_null, f_value);
let call_and_jump_on_value = let_value_is_null(
//
Stmt::Switch {
cond_symbol: value_is_null,
cond_layout: Layout::BOOL,
// value == null
branches: arena.alloc([(0, BranchInfo::None, build_closure_data_branch(env, false))]),
// value != null
default_branch: (
BranchInfo::None,
arena.alloc(build_closure_data_branch(env, true)),
),
ret_layout: hole_layout,
},
);
let joinpoint = {
let param = Param {
symbol: call_result_symbol,
layout: f_ret,
ownership: Ownership::Owned,
};
let remainder =
// f_value = f.value
let_f_value(
// f_callee = f.callee
let_f_callee(
//
call_and_jump_on_value,
),
);
Stmt::Join {
id: join_point_id,
parameters: env.arena.alloc([param]),
body: hole,
remainder: arena.alloc(remainder),
}
};
// Compile the function expression into f_val
with_hole(
env,
function_expr,
function_var,
procs,
layout_cache,
f,
env.arena.alloc(joinpoint),
)
}

View file

@ -670,6 +670,7 @@ impl<'a> TrmcEnv<'a> {
// because we do not allow polymorphic recursion, this is the only constraint
name == lambda_name
}
CallType::ByPointer { .. } => false,
CallType::Foreign { .. } | CallType::LowLevel { .. } | CallType::HigherOrder(_) => {
false
}