Allow direct packing, unpacking of erased types

This commit is contained in:
Ayaz Hafiz 2023-07-02 13:17:44 -05:00
parent 1d1db83cc7
commit cd64134b0a
No known key found for this signature in database
GPG key ID: 0E2A37416A25EF58
6 changed files with 296 additions and 185 deletions

View file

@ -17,6 +17,7 @@ fn index_erased_function<'a>(
assign_to: Symbol,
erased_function: Symbol,
field: ErasedField,
layout: InLayout<'a>,
) -> impl FnOnce(Stmt<'a>) -> Stmt<'a> {
move |rest| {
Stmt::Let(
@ -25,29 +26,7 @@ fn index_erased_function<'a>(
symbol: erased_function,
field,
},
Layout::OPAQUE_PTR,
arena.alloc(rest),
)
}
}
fn cast_erased_callee<'a>(
arena: &'a Bump,
assign_to: Symbol,
erased_function: Symbol,
fn_ptr_layout: InLayout<'a>,
) -> impl FnOnce(Stmt<'a>) -> Stmt<'a> {
move |rest| {
Stmt::Let(
assign_to,
Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::PtrCast,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: arena.alloc([erased_function]),
}),
fn_ptr_layout,
layout,
arena.alloc(rest),
)
}
@ -83,13 +62,14 @@ fn is_null<'a>(
arena: &'a Bump,
assign_to: Symbol,
ptr_symbol: Symbol,
layout: InLayout<'a>,
) -> impl FnOnce(Stmt<'a>) -> Stmt<'a> {
let null_symbol = env.unique_symbol();
move |rest| {
Stmt::Let(
null_symbol,
Expr::NullPointer,
Layout::OPAQUE_PTR,
layout,
arena.alloc(Stmt::Let(
assign_to,
Expr::Call(Call {
@ -106,6 +86,39 @@ fn is_null<'a>(
}
}
struct BuiltFunctionPointer<'a> {
function_pointer: InLayout<'a>,
reified_arguments: &'a [InLayout<'a>],
}
fn build_function_pointer<'a>(
arena: &'a Bump,
layout_cache: &mut LayoutCache<'a>,
argument_layouts: &'a [InLayout<'a>],
return_layout: InLayout<'a>,
pass_closure: bool,
) -> BuiltFunctionPointer<'a> {
let reified_arguments = if pass_closure {
let mut args = AVec::with_capacity_in(argument_layouts.len() + 1, arena);
args.extend(argument_layouts.iter().chain(&[Layout::ERASED]).copied());
args.into_bump_slice()
} else {
argument_layouts
};
let fn_ptr_layout = LayoutRepr::FunctionPointer(FunctionPointer {
args: reified_arguments,
ret: return_layout,
});
let function_pointer = layout_cache.put_in_direct_no_semantic(fn_ptr_layout);
BuiltFunctionPointer {
function_pointer,
reified_arguments,
}
}
/// Given
///
/// ```
@ -118,7 +131,7 @@ fn is_null<'a>(
/// f = compile(f)
/// joinpoint join result:
/// <hole>
/// f_value: Ptr<[]> = ErasedLoad(f, .value)
/// f_value: Box<[]> = ErasedLoad(f, .value)
/// f_callee: Ptr<[]> = ErasedLoad(f, .callee)
/// if (f_value != nullptr) {
/// f_callee = Cast(f_callee, (..params, Erased) -> ret);
@ -151,55 +164,44 @@ pub fn call_erased_function<'a>(
// f_value = ErasedLoad(f, .value)
let f_value = env.unique_symbol();
let let_f_value = index_erased_function(arena, f_value, f, ErasedField::Value);
// f_callee = ErasedLoad(f, .callee)
let f_callee = env.unique_symbol();
let let_f_callee = index_erased_function(arena, f_callee, f, ErasedField::Callee);
let let_f_value =
index_erased_function(arena, f_value, f, ErasedField::Value, Layout::OPAQUE_PTR);
let mut build_closure_data_branch = |env: &mut Env, pass_closure| {
// f_callee = Cast(f_callee, (..params) -> ret);
// result = f_callee ..args
// jump join result
let (f_args, function_argument_symbols) = if pass_closure {
// f_args = ...args, f
let BuiltFunctionPointer {
function_pointer,
reified_arguments: f_args,
} = build_function_pointer(arena, layout_cache, f_args, f_ret, pass_closure);
// f_callee = ErasedLoad(f, .callee)
let f_callee = env.unique_symbol();
let let_f_callee =
index_erased_function(arena, f_callee, f, ErasedField::Callee, function_pointer);
let function_argument_symbols = if pass_closure {
// function_argument_symbols = ...args, f.value
let f_args = {
let mut args = AVec::with_capacity_in(f_args.len() + 1, arena);
args.extend(f_args.iter().chain(&[Layout::ERASED]).copied());
args.into_bump_slice()
};
let function_argument_symbols = {
let mut args = AVec::with_capacity_in(function_argument_symbols.len() + 1, arena);
args.extend(function_argument_symbols.iter().chain(&[f]));
args.into_bump_slice()
};
(f_args, function_argument_symbols)
let mut args = AVec::with_capacity_in(function_argument_symbols.len() + 1, arena);
args.extend(function_argument_symbols.iter().chain(&[f]));
args.into_bump_slice()
} else {
(f_args, function_argument_symbols)
function_argument_symbols
};
let fn_ptr_layout =
layout_cache.put_in_direct_no_semantic(LayoutRepr::FunctionPointer(FunctionPointer {
args: f_args,
ret: f_ret,
}));
let f_callee_cast = env.unique_symbol();
let let_f_callee_cast = cast_erased_callee(arena, f_callee_cast, f_callee, fn_ptr_layout);
let result = env.unique_symbol();
let let_result = call_callee(
arena,
result,
f_ret,
f_callee_cast,
f_callee,
f_args,
function_argument_symbols,
);
let_f_callee_cast(
let_f_callee(
//
let_result(
//
@ -209,7 +211,7 @@ pub fn call_erased_function<'a>(
};
let value_is_null = env.unique_symbol();
let let_value_is_null = is_null(env, arena, value_is_null, f_value);
let let_value_is_null = is_null(env, arena, value_is_null, f_value, Layout::OPAQUE_PTR);
let call_and_jump_on_value = let_value_is_null(
//
@ -234,14 +236,10 @@ pub fn call_erased_function<'a>(
ownership: Ownership::Owned,
};
let remainder =
let remainder = let_f_value(
// f_value = ErasedLoad(f, .value)
let_f_value(
// f_callee = ErasedLoad(f, .callee)
let_f_callee(
//
call_and_jump_on_value,
),
// <rest>
call_and_jump_on_value,
);
Stmt::Join {
@ -273,8 +271,7 @@ pub fn call_erased_function<'a>(
/// We generate
///
/// ```
/// boxed_value = Expr::Box({s})
/// stack_value: Ptr<[]> = Cast(boxed_value, Ptr<[]>)
/// value = Expr::Box({s})
/// callee = Expr::FunctionPointer(f)
/// f = Expr::ErasedMake({ value, callee })
/// ```
@ -301,6 +298,8 @@ pub fn build_erased_function<'a>(
let ResolvedErasedLambda {
captures,
lambda_name,
arguments,
ret,
} = resolved_lambda;
let value = match captures {
@ -319,11 +318,16 @@ pub fn build_erased_function<'a>(
hole,
);
let BuiltFunctionPointer {
function_pointer,
reified_arguments: _,
} = build_function_pointer(env.arena, layout_cache, arguments, ret, captures.is_some());
// callee = Expr::FunctionPointer(f)
let result = Stmt::Let(
callee,
Expr::FunctionPointer { lambda_name },
Layout::OPAQUE_PTR,
function_pointer,
env.arena.alloc(result),
);
@ -340,25 +344,11 @@ pub fn build_erased_function<'a>(
let stack_captures_layout =
layout_cache.put_in_direct_no_semantic(LayoutRepr::Struct(layouts));
let boxed_captures = env.unique_symbol();
let boxed_captures_layout =
layout_cache.put_in_direct_no_semantic(LayoutRepr::Boxed(stack_captures_layout));
let result = Stmt::Let(
value.unwrap(),
Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::PtrCast,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: env.arena.alloc([boxed_captures]),
}),
Layout::OPAQUE_PTR,
env.arena.alloc(result),
);
let result = Stmt::Let(
boxed_captures,
Expr::ExprBox {
symbol: stack_captures,
},
@ -386,6 +376,8 @@ struct ResolvedErasedCaptures<'a> {
pub struct ResolvedErasedLambda<'a> {
captures: Option<ResolvedErasedCaptures<'a>>,
lambda_name: LambdaName<'a>,
arguments: &'a [InLayout<'a>],
ret: InLayout<'a>,
}
impl<'a> ResolvedErasedLambda<'a> {
@ -394,6 +386,8 @@ impl<'a> ResolvedErasedLambda<'a> {
layout_cache: &mut LayoutCache<'a>,
lambda_symbol: Symbol,
captures: CapturedSymbols<'a>,
arguments: &'a [InLayout<'a>],
ret: InLayout<'a>,
) -> Self {
let resolved_captures;
let lambda_name;
@ -422,6 +416,8 @@ impl<'a> ResolvedErasedLambda<'a> {
Self {
captures: resolved_captures,
lambda_name,
arguments,
ret,
}
}
@ -454,7 +450,6 @@ pub fn unpack_closure_data<'a>(
captures: &[(Symbol, Variable)],
mut hole: Stmt<'a>,
) -> Stmt<'a> {
let loaded_captures = env.unique_symbol();
let heap_captures = env.unique_symbol();
let stack_captures = env.unique_symbol();
@ -494,24 +489,12 @@ pub fn unpack_closure_data<'a>(
env.arena.alloc(hole),
);
hole = Stmt::Let(
heap_captures,
Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::PtrCast,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: env.arena.alloc([captures_symbol]),
}),
heap_captures_layout,
env.arena.alloc(hole),
);
let let_loaded_captures = index_erased_function(
env.arena,
loaded_captures,
heap_captures,
captures_symbol,
ErasedField::Value,
heap_captures_layout,
);
let_loaded_captures(hole)