mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-26 13:29:12 +00:00
Call erased functions
This commit is contained in:
parent
558d7459b4
commit
510a421748
10 changed files with 350 additions and 23 deletions
|
@ -9,6 +9,7 @@ use bumpalo::collections::Vec;
|
|||
use bumpalo::Bump;
|
||||
use roc_collections::all::{MutMap, MutSet};
|
||||
use roc_collections::ReferenceMatrix;
|
||||
use roc_error_macros::todo_lambda_erasure;
|
||||
use roc_module::low_level::LowLevel;
|
||||
use roc_module::symbol::Symbol;
|
||||
|
||||
|
@ -560,6 +561,10 @@ impl<'a> BorrowInfState<'a> {
|
|||
self.own_args_using_params(arguments, ps);
|
||||
}
|
||||
|
||||
ByPointer { .. } => {
|
||||
todo_lambda_erasure!()
|
||||
}
|
||||
|
||||
LowLevel { op, .. } => {
|
||||
debug_assert!(!op.is_higher_order());
|
||||
|
||||
|
@ -1056,6 +1061,9 @@ fn call_info_call<'a>(call: &crate::ir::Call<'a>, info: &mut CallInfo<'a>) {
|
|||
ByName { name, .. } => {
|
||||
info.keys.push(name.name());
|
||||
}
|
||||
ByPointer { .. } => {
|
||||
todo_lambda_erasure!()
|
||||
}
|
||||
Foreign { .. } => {}
|
||||
LowLevel { .. } => {}
|
||||
HigherOrder(HigherOrderLowLevel {
|
||||
|
|
|
@ -10,8 +10,8 @@ use crate::{
|
|||
ModifyRc, Param, Proc, ProcLayout, Stmt,
|
||||
},
|
||||
layout::{
|
||||
Builtin, InLayout, Layout, LayoutInterner, LayoutRepr, STLayoutInterner, TagIdIntType,
|
||||
UnionLayout,
|
||||
Builtin, FunctionPointer, InLayout, Layout, LayoutInterner, LayoutRepr, STLayoutInterner,
|
||||
TagIdIntType, UnionLayout,
|
||||
},
|
||||
};
|
||||
|
||||
|
@ -642,6 +642,23 @@ impl<'a, 'r> Ctx<'a, 'r> {
|
|||
}
|
||||
Some(*ret_layout)
|
||||
}
|
||||
CallType::ByPointer {
|
||||
pointer,
|
||||
ret_layout,
|
||||
arg_layouts,
|
||||
} => {
|
||||
let expected_layout =
|
||||
self.interner
|
||||
.insert_direct_no_semantic(LayoutRepr::FunctionPointer(FunctionPointer {
|
||||
args: arg_layouts,
|
||||
ret: *ret_layout,
|
||||
}));
|
||||
self.check_sym_layout(*pointer, expected_layout, UseKind::SwitchCond);
|
||||
for (arg, wanted_layout) in arguments.iter().zip(arg_layouts.iter()) {
|
||||
self.check_sym_layout(*arg, *wanted_layout, UseKind::CallArg);
|
||||
}
|
||||
Some(*ret_layout)
|
||||
}
|
||||
CallType::HigherOrder(HigherOrderLowLevel {
|
||||
op: _,
|
||||
closure_env_layout: _,
|
||||
|
|
|
@ -10,7 +10,7 @@ use std::{collections::HashMap, hash::BuildHasherDefault};
|
|||
use bumpalo::collections::{CollectIn, Vec};
|
||||
use bumpalo::Bump;
|
||||
use roc_collections::{all::WyHash, MutMap, MutSet};
|
||||
use roc_error_macros::internal_error;
|
||||
use roc_error_macros::{internal_error, todo_lambda_erasure};
|
||||
use roc_module::low_level::LowLevel;
|
||||
use roc_module::{low_level::LowLevelWrapperType, symbol::Symbol};
|
||||
|
||||
|
@ -942,6 +942,9 @@ fn insert_refcount_operations_binding<'a>(
|
|||
|
||||
inc_owned!(arguments.iter().copied(), new_let)
|
||||
}
|
||||
CallType::ByPointer { .. } => {
|
||||
todo_lambda_erasure!()
|
||||
}
|
||||
CallType::Foreign { .. } => {
|
||||
// Foreign functions should be responsible for their own memory management.
|
||||
// But previously they were assumed to be called with borrowed parameters, so we do the same now.
|
||||
|
|
|
@ -5452,25 +5452,21 @@ pub fn with_hole<'a>(
|
|||
env.arena.alloc(result),
|
||||
);
|
||||
}
|
||||
RawFunctionLayout::ErasedFunction(..) => {
|
||||
// What we want here is
|
||||
// f = compile(loc_expr)
|
||||
// joinpoint join result:
|
||||
// <hole>
|
||||
// if (f.value) {
|
||||
// f = cast(f, (..params, void*) -> ret);
|
||||
// result = f ..args
|
||||
// jump join result
|
||||
// } else {
|
||||
// f = cast(f, (..params) -> ret);
|
||||
// result = f ..args
|
||||
// jump join result
|
||||
// }
|
||||
todo_lambda_erasure!(
|
||||
"{:?} :: {:?}",
|
||||
RawFunctionLayout::ErasedFunction(arg_layouts, ret_layout) => {
|
||||
let hole_layout =
|
||||
layout_cache.from_var(env.arena, fn_var, env.subs).unwrap();
|
||||
result = erased::call_erased_function(
|
||||
env,
|
||||
layout_cache,
|
||||
procs,
|
||||
loc_expr.value,
|
||||
full_layout
|
||||
)
|
||||
fn_var,
|
||||
(arg_layouts, ret_layout),
|
||||
arg_symbols,
|
||||
assigned,
|
||||
hole,
|
||||
hole_layout,
|
||||
);
|
||||
}
|
||||
RawFunctionLayout::ZeroArgumentThunk(_) => {
|
||||
unreachable!(
|
||||
|
|
286
crates/compiler/mono/src/ir/erased.rs
Normal file
286
crates/compiler/mono/src/ir/erased.rs
Normal file
|
@ -0,0 +1,286 @@
|
|||
use bumpalo::{collections::Vec as AVec, Bump};
|
||||
use roc_module::{low_level::LowLevel, symbol::Symbol};
|
||||
use roc_types::subs::Variable;
|
||||
|
||||
use crate::{
|
||||
borrow::Ownership,
|
||||
layout::{FunctionPointer, InLayout, Layout, LayoutCache, LayoutRepr},
|
||||
};
|
||||
|
||||
use super::{
|
||||
with_hole, BranchInfo, Call, CallType, Env, Expr, JoinPointId, Param, Procs, Stmt, UpdateModeId,
|
||||
};
|
||||
|
||||
const ERASED_FUNCTION_FIELD_LAYOUTS: &[InLayout] =
|
||||
&[Layout::OPAQUE_PTR, Layout::OPAQUE_PTR, Layout::OPAQUE_PTR];
|
||||
|
||||
/// The layout of an erased function is
|
||||
///
|
||||
/// ```
|
||||
/// {
|
||||
/// value: void*,
|
||||
/// callee: void*,
|
||||
/// refcounter: void*,
|
||||
/// }
|
||||
/// ```
|
||||
fn erased_function_layout<'a>(layout_cache: &mut LayoutCache<'a>) -> InLayout<'a> {
|
||||
layout_cache.put_in_direct_no_semantic(LayoutRepr::Struct(ERASED_FUNCTION_FIELD_LAYOUTS))
|
||||
}
|
||||
|
||||
#[repr(u8)]
|
||||
enum ErasedFunctionIndex {
|
||||
Value = 0,
|
||||
Callee = 1,
|
||||
RefCounter = 2,
|
||||
}
|
||||
|
||||
fn index_erased_function<'a>(
|
||||
arena: &'a Bump,
|
||||
assign_to: Symbol,
|
||||
erased_function: Symbol,
|
||||
index: ErasedFunctionIndex,
|
||||
) -> impl FnOnce(Stmt<'a>) -> Stmt<'a> {
|
||||
move |rest| {
|
||||
Stmt::Let(
|
||||
assign_to,
|
||||
Expr::StructAtIndex {
|
||||
index: index as _,
|
||||
structure: erased_function,
|
||||
field_layouts: ERASED_FUNCTION_FIELD_LAYOUTS,
|
||||
},
|
||||
Layout::OPAQUE_PTR,
|
||||
arena.alloc(rest),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn cast_erased_callee<'a>(
|
||||
arena: &'a Bump,
|
||||
assign_to: Symbol,
|
||||
erased_function: Symbol,
|
||||
fn_ptr_layout: InLayout<'a>,
|
||||
) -> impl FnOnce(Stmt<'a>) -> Stmt<'a> {
|
||||
move |rest| {
|
||||
Stmt::Let(
|
||||
assign_to,
|
||||
Expr::Call(Call {
|
||||
call_type: CallType::LowLevel {
|
||||
op: LowLevel::PtrCast,
|
||||
update_mode: UpdateModeId::BACKEND_DUMMY,
|
||||
},
|
||||
arguments: arena.alloc([erased_function]),
|
||||
}),
|
||||
fn_ptr_layout,
|
||||
arena.alloc(rest),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn call_callee<'a>(
|
||||
arena: &'a Bump,
|
||||
result_symbol: Symbol,
|
||||
result: InLayout<'a>,
|
||||
fn_ptr_symbol: Symbol,
|
||||
fn_arg_layouts: &'a [InLayout<'a>],
|
||||
fn_arguments: &'a [Symbol],
|
||||
) -> impl FnOnce(Stmt<'a>) -> Stmt<'a> {
|
||||
move |rest| {
|
||||
Stmt::Let(
|
||||
result_symbol,
|
||||
Expr::Call(Call {
|
||||
call_type: CallType::ByPointer {
|
||||
pointer: fn_ptr_symbol,
|
||||
ret_layout: result,
|
||||
arg_layouts: fn_arg_layouts,
|
||||
},
|
||||
arguments: fn_arguments,
|
||||
}),
|
||||
result,
|
||||
arena.alloc(rest),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn is_null<'a>(
|
||||
env: &mut Env<'a, '_>,
|
||||
arena: &'a Bump,
|
||||
assign_to: Symbol,
|
||||
ptr_symbol: Symbol,
|
||||
) -> impl FnOnce(Stmt<'a>) -> Stmt<'a> {
|
||||
let null_symbol = env.unique_symbol();
|
||||
move |rest| {
|
||||
Stmt::Let(
|
||||
null_symbol,
|
||||
Expr::NullPointer,
|
||||
Layout::OPAQUE_PTR,
|
||||
arena.alloc(Stmt::Let(
|
||||
assign_to,
|
||||
Expr::Call(Call {
|
||||
call_type: CallType::LowLevel {
|
||||
op: LowLevel::Eq,
|
||||
update_mode: UpdateModeId::BACKEND_DUMMY,
|
||||
},
|
||||
arguments: arena.alloc([ptr_symbol, null_symbol]),
|
||||
}),
|
||||
Layout::BOOL,
|
||||
arena.alloc(rest),
|
||||
)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Given
|
||||
///
|
||||
/// ```
|
||||
/// Call(f, args)
|
||||
/// ```
|
||||
///
|
||||
/// We generate
|
||||
///
|
||||
/// ```
|
||||
/// f = compile(f)
|
||||
/// joinpoint join result:
|
||||
/// <hole>
|
||||
/// if (f.value) {
|
||||
/// f = cast(f.callee, (..params, void*) -> ret);
|
||||
/// result = f ..args f.value
|
||||
/// jump join result
|
||||
/// } else {
|
||||
/// f = cast(f.callee, (..params) -> ret);
|
||||
/// result = f ..args
|
||||
/// jump join result
|
||||
/// }
|
||||
/// ```
|
||||
pub fn call_erased_function<'a>(
|
||||
env: &mut Env<'a, '_>,
|
||||
layout_cache: &mut LayoutCache<'a>,
|
||||
procs: &mut Procs<'a>,
|
||||
function_expr: roc_can::expr::Expr,
|
||||
function_var: Variable,
|
||||
function_signature: (&'a [InLayout<'a>], InLayout<'a>),
|
||||
function_argument_symbols: &'a [Symbol],
|
||||
call_result_symbol: Symbol,
|
||||
hole: &'a Stmt<'a>,
|
||||
hole_layout: InLayout<'a>,
|
||||
) -> Stmt<'a> {
|
||||
let arena = env.arena;
|
||||
let (f_args, f_ret) = function_signature;
|
||||
|
||||
let f = env.unique_symbol();
|
||||
|
||||
let join_point_id = JoinPointId(env.unique_symbol());
|
||||
|
||||
// f_value = f.value
|
||||
let f_value = env.unique_symbol();
|
||||
let let_f_value = index_erased_function(arena, f_value, f, ErasedFunctionIndex::Value);
|
||||
|
||||
// f_callee = f.callee
|
||||
let f_callee = env.unique_symbol();
|
||||
let let_f_callee = index_erased_function(arena, f_callee, f, ErasedFunctionIndex::Callee);
|
||||
|
||||
let mut build_closure_data_branch = |env: &mut Env, pass_closure| {
|
||||
// f_callee = cast(f_callee, (..params) -> ret);
|
||||
// result = f_callee ..args
|
||||
// jump join result
|
||||
|
||||
let (f_args, function_argument_symbols) = if pass_closure {
|
||||
// f_args = ...args, f.value
|
||||
// function_argument_symbols = ...args, f.value
|
||||
let f_args = {
|
||||
let mut args = AVec::with_capacity_in(f_args.len() + 1, arena);
|
||||
args.extend(f_args.iter().chain(&[Layout::OPAQUE_PTR]).copied());
|
||||
args.into_bump_slice()
|
||||
};
|
||||
let function_argument_symbols = {
|
||||
let mut args = AVec::with_capacity_in(function_argument_symbols.len() + 1, arena);
|
||||
args.extend(function_argument_symbols.iter().chain(&[f_value]));
|
||||
args.into_bump_slice()
|
||||
};
|
||||
(f_args, function_argument_symbols)
|
||||
} else {
|
||||
(f_args, function_argument_symbols)
|
||||
};
|
||||
|
||||
let fn_ptr_layout =
|
||||
layout_cache.put_in_direct_no_semantic(LayoutRepr::FunctionPointer(FunctionPointer {
|
||||
args: f_args,
|
||||
ret: f_ret,
|
||||
}));
|
||||
|
||||
let f_callee_cast = env.unique_symbol();
|
||||
let let_f_callee_cast = cast_erased_callee(arena, f_callee_cast, f_callee, fn_ptr_layout);
|
||||
|
||||
let result = env.unique_symbol();
|
||||
let let_result = call_callee(
|
||||
arena,
|
||||
result,
|
||||
f_ret,
|
||||
f_callee_cast,
|
||||
f_args,
|
||||
function_argument_symbols,
|
||||
);
|
||||
|
||||
let_f_callee_cast(
|
||||
//
|
||||
let_result(
|
||||
//
|
||||
Stmt::Jump(join_point_id, arena.alloc([result])),
|
||||
),
|
||||
)
|
||||
};
|
||||
|
||||
let value_is_null = env.unique_symbol();
|
||||
let let_value_is_null = is_null(env, arena, value_is_null, f_value);
|
||||
|
||||
let call_and_jump_on_value = let_value_is_null(
|
||||
//
|
||||
Stmt::Switch {
|
||||
cond_symbol: value_is_null,
|
||||
cond_layout: Layout::BOOL,
|
||||
// value == null
|
||||
branches: arena.alloc([(0, BranchInfo::None, build_closure_data_branch(env, false))]),
|
||||
// value != null
|
||||
default_branch: (
|
||||
BranchInfo::None,
|
||||
arena.alloc(build_closure_data_branch(env, true)),
|
||||
),
|
||||
ret_layout: hole_layout,
|
||||
},
|
||||
);
|
||||
|
||||
let joinpoint = {
|
||||
let param = Param {
|
||||
symbol: call_result_symbol,
|
||||
layout: f_ret,
|
||||
ownership: Ownership::Owned,
|
||||
};
|
||||
|
||||
let remainder =
|
||||
// f_value = f.value
|
||||
let_f_value(
|
||||
// f_callee = f.callee
|
||||
let_f_callee(
|
||||
//
|
||||
call_and_jump_on_value,
|
||||
),
|
||||
);
|
||||
|
||||
Stmt::Join {
|
||||
id: join_point_id,
|
||||
parameters: env.arena.alloc([param]),
|
||||
body: hole,
|
||||
remainder: arena.alloc(remainder),
|
||||
}
|
||||
};
|
||||
|
||||
// Compile the function expression into f_val
|
||||
with_hole(
|
||||
env,
|
||||
function_expr,
|
||||
function_var,
|
||||
procs,
|
||||
layout_cache,
|
||||
f,
|
||||
env.arena.alloc(joinpoint),
|
||||
)
|
||||
}
|
|
@ -670,6 +670,7 @@ impl<'a> TrmcEnv<'a> {
|
|||
// because we do not allow polymorphic recursion, this is the only constraint
|
||||
name == lambda_name
|
||||
}
|
||||
CallType::ByPointer { .. } => false,
|
||||
CallType::Foreign { .. } | CallType::LowLevel { .. } | CallType::HigherOrder(_) => {
|
||||
false
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue