Allow direct packing, unpacking of erased types

This commit is contained in:
Ayaz Hafiz 2023-07-02 13:17:44 -05:00
parent 1d1db83cc7
commit cd64134b0a
No known key found for this signature in database
GPG key ID: 0E2A37416A25EF58
6 changed files with 296 additions and 185 deletions

View file

@ -29,6 +29,7 @@ pub enum UseKind {
ExpectLookup,
ErasedMake(ErasedField),
Erased,
FunctionPointer,
}
pub enum ProblemKind<'a> {
@ -119,6 +120,24 @@ pub enum ProblemKind<'a> {
num_needed: usize,
num_given: usize,
},
ErasedMakeValueNotBoxed {
symbol: Symbol,
def_layout: InLayout<'a>,
def_line: usize,
},
ErasedMakeCalleeNotFunctionPointer {
symbol: Symbol,
def_layout: InLayout<'a>,
def_line: usize,
},
ErasedLoadValueNotBoxed {
symbol: Symbol,
target_layout: InLayout<'a>,
},
ErasedLoadCalleeNotFunctionPointer {
symbol: Symbol,
target_layout: InLayout<'a>,
},
}
pub struct Problem<'a> {
@ -276,7 +295,7 @@ impl<'a, 'r> Ctx<'a, 'r> {
match body {
Stmt::Let(x, e, x_layout, rest) => {
if let Some(e_layout) = self.check_expr(e) {
if let Some(e_layout) = self.check_expr(e, *x_layout) {
if self.not_equiv(e_layout, *x_layout) {
self.problem(ProblemKind::SymbolDefMismatch {
symbol: *x,
@ -393,7 +412,7 @@ impl<'a, 'r> Ctx<'a, 'r> {
}
}
fn check_expr(&mut self, e: &Expr<'a>) -> Option<InLayout<'a>> {
fn check_expr(&mut self, e: &Expr<'a>, target_layout: InLayout<'a>) -> Option<InLayout<'a>> {
match e {
Expr::Literal(_) => None,
Expr::NullPointer => None,
@ -486,11 +505,10 @@ impl<'a, 'r> Ctx<'a, 'r> {
}
}
}),
&Expr::ErasedMake { value, callee } => {
self.check_erased_make(value, callee);
Some(Layout::ERASED)
&Expr::ErasedMake { value, callee } => Some(self.check_erased_make(value, callee)),
&Expr::ErasedLoad { symbol, field } => {
Some(self.check_erased_load(symbol, field, target_layout))
}
&Expr::ErasedLoad { symbol, field } => Some(self.check_erased_load(symbol, field)),
&Expr::FunctionPointer { lambda_name } => {
let lambda_symbol = lambda_name.name();
if !self.procs.iter().any(|((name, proc), _)| {
@ -500,7 +518,7 @@ impl<'a, 'r> Ctx<'a, 'r> {
symbol: lambda_symbol,
});
}
Some(Layout::OPAQUE_PTR)
Some(target_layout)
}
&Expr::Reset {
symbol,
@ -691,7 +709,7 @@ impl<'a, 'r> Ctx<'a, 'r> {
args: arg_layouts,
ret: *ret_layout,
}));
self.check_sym_layout(*pointer, expected_layout, UseKind::SwitchCond);
self.check_sym_layout(*pointer, expected_layout, UseKind::FunctionPointer);
for (arg, wanted_layout) in arguments.iter().zip(arg_layouts.iter()) {
self.check_sym_layout(*arg, *wanted_layout, UseKind::CallArg);
}
@ -756,28 +774,67 @@ impl<'a, 'r> Ctx<'a, 'r> {
}
}
fn check_erased_make(&mut self, value: Option<Symbol>, callee: Symbol) {
fn check_erased_make(&mut self, value: Option<Symbol>, callee: Symbol) -> InLayout<'a> {
if let Some(value) = value {
self.check_sym_layout(
value,
Layout::OPAQUE_PTR,
UseKind::ErasedMake(ErasedField::Value),
);
self.with_sym_layout(value, |this, def_line, layout| {
let repr = this.interner.get_repr(layout);
if !matches!(repr, LayoutRepr::Boxed(_)) {
this.problem(ProblemKind::ErasedMakeValueNotBoxed {
symbol: value,
def_layout: layout,
def_line,
});
}
Option::<()>::None
});
}
self.check_sym_layout(
callee,
Layout::OPAQUE_PTR,
UseKind::ErasedMake(ErasedField::Callee),
);
self.with_sym_layout(callee, |this, def_line, layout| {
let repr = this.interner.get_repr(layout);
if !matches!(repr, LayoutRepr::FunctionPointer(_)) {
this.problem(ProblemKind::ErasedMakeCalleeNotFunctionPointer {
symbol: callee,
def_layout: layout,
def_line,
});
}
Option::<()>::None
});
Layout::ERASED
}
fn check_erased_load(&mut self, symbol: Symbol, field: ErasedField) -> InLayout<'a> {
fn check_erased_load(
&mut self,
symbol: Symbol,
field: ErasedField,
target_layout: InLayout<'a>,
) -> InLayout<'a> {
self.check_sym_layout(symbol, Layout::ERASED, UseKind::Erased);
match field {
ErasedField::Value => Layout::OPAQUE_PTR,
ErasedField::Callee => Layout::OPAQUE_PTR,
ErasedField::Value => {
let repr = self.interner.get_repr(target_layout);
if !matches!(repr, LayoutRepr::Boxed(_)) {
self.problem(ProblemKind::ErasedLoadValueNotBoxed {
symbol,
target_layout,
});
}
}
ErasedField::Callee => {
let repr = self.interner.get_repr(target_layout);
if !matches!(repr, LayoutRepr::FunctionPointer(_)) {
self.problem(ProblemKind::ErasedLoadCalleeNotFunctionPointer {
symbol,
target_layout,
});
}
}
}
target_layout
}
}

View file

@ -415,6 +415,74 @@ where
f.as_string(num_given),
])
}
ProblemKind::ErasedMakeValueNotBoxed {
symbol,
def_layout,
def_line,
} => {
title = "ERASED VALUE IS NOT BOXED";
docs_before = vec![(
def_line,
f.concat([
f.reflow("The value "),
format_symbol(f, interns, symbol),
f.reflow(" defined here"),
]),
)];
f.concat([
f.reflow("must be boxed in order to be erased, but has layout "),
interner.to_doc_top(def_layout, f),
])
}
ProblemKind::ErasedMakeCalleeNotFunctionPointer {
symbol,
def_layout,
def_line,
} => {
title = "ERASED CALLEE IS NOT A FUNCTION POINTER";
docs_before = vec![(
def_line,
f.concat([
f.reflow("The value "),
format_symbol(f, interns, symbol),
f.reflow(" defined here"),
]),
)];
f.concat([
f.reflow(
"must be a function pointer in order to be an erasure callee, but has layout ",
),
interner.to_doc_top(def_layout, f),
])
}
ProblemKind::ErasedLoadValueNotBoxed {
symbol,
target_layout,
} => {
title = "ERASED VALUE IS NOT BOXED";
docs_before = vec![];
f.concat([
f.reflow("The erased value load "),
format_symbol(f, interns, symbol),
f.reflow(" has layout "),
interner.to_doc_top(target_layout, f),
f.reflow(", but should be boxed!"),
])
}
ProblemKind::ErasedLoadCalleeNotFunctionPointer {
symbol,
target_layout,
} => {
title = "ERASED CALLEE IS NOT A FUNCTION POINTER";
docs_before = vec![];
f.concat([
f.reflow("The erased callee load "),
format_symbol(f, interns, symbol),
f.reflow(" has layout "),
interner.to_doc_top(target_layout, f),
f.reflow(", but should be a function pointer!"),
])
}
};
(title, docs_before, doc)
}
@ -443,6 +511,7 @@ fn format_use_kind(use_kind: UseKind) -> &'static str {
ErasedField::Callee => "erased callee field",
},
UseKind::Erased => "erasure",
UseKind::FunctionPointer => "function pointer",
}
}

View file

@ -5268,12 +5268,18 @@ pub fn with_hole<'a>(
RawFunctionLayout::ZeroArgumentThunk(_) => {
unreachable!("a closure syntactically always must have at least one argument")
}
RawFunctionLayout::ErasedFunction(_argument_layouts, _ret_layout) => {
RawFunctionLayout::ErasedFunction(argument_layouts, ret_layout) => {
let captured_symbols = Vec::from_iter_in(captured_symbols, env.arena);
let captured_symbols = captured_symbols.into_bump_slice();
let captured_symbols = CapturedSymbols::Captured(captured_symbols);
let resolved_erased_lambda =
ResolvedErasedLambda::new(env, layout_cache, name, captured_symbols);
let resolved_erased_lambda = ResolvedErasedLambda::new(
env,
layout_cache,
name,
captured_symbols,
argument_layouts,
ret_layout,
);
let inserted = procs.insert_anonymous(
env,
@ -8340,9 +8346,15 @@ fn specialize_symbol<'a>(
)
}
}
RawFunctionLayout::ErasedFunction(..) => {
let erased_lambda =
erased::ResolvedErasedLambda::new(env, layout_cache, original, captured);
RawFunctionLayout::ErasedFunction(argument_layouts, ret_layout) => {
let erased_lambda = erased::ResolvedErasedLambda::new(
env,
layout_cache,
original,
captured,
argument_layouts,
ret_layout,
);
let lambda_name = erased_lambda.lambda_name();
let proc_layout =

View file

@ -17,6 +17,7 @@ fn index_erased_function<'a>(
assign_to: Symbol,
erased_function: Symbol,
field: ErasedField,
layout: InLayout<'a>,
) -> impl FnOnce(Stmt<'a>) -> Stmt<'a> {
move |rest| {
Stmt::Let(
@ -25,29 +26,7 @@ fn index_erased_function<'a>(
symbol: erased_function,
field,
},
Layout::OPAQUE_PTR,
arena.alloc(rest),
)
}
}
fn cast_erased_callee<'a>(
arena: &'a Bump,
assign_to: Symbol,
erased_function: Symbol,
fn_ptr_layout: InLayout<'a>,
) -> impl FnOnce(Stmt<'a>) -> Stmt<'a> {
move |rest| {
Stmt::Let(
assign_to,
Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::PtrCast,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: arena.alloc([erased_function]),
}),
fn_ptr_layout,
layout,
arena.alloc(rest),
)
}
@ -83,13 +62,14 @@ fn is_null<'a>(
arena: &'a Bump,
assign_to: Symbol,
ptr_symbol: Symbol,
layout: InLayout<'a>,
) -> impl FnOnce(Stmt<'a>) -> Stmt<'a> {
let null_symbol = env.unique_symbol();
move |rest| {
Stmt::Let(
null_symbol,
Expr::NullPointer,
Layout::OPAQUE_PTR,
layout,
arena.alloc(Stmt::Let(
assign_to,
Expr::Call(Call {
@ -106,6 +86,39 @@ fn is_null<'a>(
}
}
struct BuiltFunctionPointer<'a> {
function_pointer: InLayout<'a>,
reified_arguments: &'a [InLayout<'a>],
}
fn build_function_pointer<'a>(
arena: &'a Bump,
layout_cache: &mut LayoutCache<'a>,
argument_layouts: &'a [InLayout<'a>],
return_layout: InLayout<'a>,
pass_closure: bool,
) -> BuiltFunctionPointer<'a> {
let reified_arguments = if pass_closure {
let mut args = AVec::with_capacity_in(argument_layouts.len() + 1, arena);
args.extend(argument_layouts.iter().chain(&[Layout::ERASED]).copied());
args.into_bump_slice()
} else {
argument_layouts
};
let fn_ptr_layout = LayoutRepr::FunctionPointer(FunctionPointer {
args: reified_arguments,
ret: return_layout,
});
let function_pointer = layout_cache.put_in_direct_no_semantic(fn_ptr_layout);
BuiltFunctionPointer {
function_pointer,
reified_arguments,
}
}
/// Given
///
/// ```
@ -118,7 +131,7 @@ fn is_null<'a>(
/// f = compile(f)
/// joinpoint join result:
/// <hole>
/// f_value: Ptr<[]> = ErasedLoad(f, .value)
/// f_value: Box<[]> = ErasedLoad(f, .value)
/// f_callee: Ptr<[]> = ErasedLoad(f, .callee)
/// if (f_value != nullptr) {
/// f_callee = Cast(f_callee, (..params, Erased) -> ret);
@ -151,55 +164,44 @@ pub fn call_erased_function<'a>(
// f_value = ErasedLoad(f, .value)
let f_value = env.unique_symbol();
let let_f_value = index_erased_function(arena, f_value, f, ErasedField::Value);
// f_callee = ErasedLoad(f, .callee)
let f_callee = env.unique_symbol();
let let_f_callee = index_erased_function(arena, f_callee, f, ErasedField::Callee);
let let_f_value =
index_erased_function(arena, f_value, f, ErasedField::Value, Layout::OPAQUE_PTR);
let mut build_closure_data_branch = |env: &mut Env, pass_closure| {
// f_callee = Cast(f_callee, (..params) -> ret);
// result = f_callee ..args
// jump join result
let (f_args, function_argument_symbols) = if pass_closure {
// f_args = ...args, f
let BuiltFunctionPointer {
function_pointer,
reified_arguments: f_args,
} = build_function_pointer(arena, layout_cache, f_args, f_ret, pass_closure);
// f_callee = ErasedLoad(f, .callee)
let f_callee = env.unique_symbol();
let let_f_callee =
index_erased_function(arena, f_callee, f, ErasedField::Callee, function_pointer);
let function_argument_symbols = if pass_closure {
// function_argument_symbols = ...args, f.value
let f_args = {
let mut args = AVec::with_capacity_in(f_args.len() + 1, arena);
args.extend(f_args.iter().chain(&[Layout::ERASED]).copied());
args.into_bump_slice()
};
let function_argument_symbols = {
let mut args = AVec::with_capacity_in(function_argument_symbols.len() + 1, arena);
args.extend(function_argument_symbols.iter().chain(&[f]));
args.into_bump_slice()
};
(f_args, function_argument_symbols)
let mut args = AVec::with_capacity_in(function_argument_symbols.len() + 1, arena);
args.extend(function_argument_symbols.iter().chain(&[f]));
args.into_bump_slice()
} else {
(f_args, function_argument_symbols)
function_argument_symbols
};
let fn_ptr_layout =
layout_cache.put_in_direct_no_semantic(LayoutRepr::FunctionPointer(FunctionPointer {
args: f_args,
ret: f_ret,
}));
let f_callee_cast = env.unique_symbol();
let let_f_callee_cast = cast_erased_callee(arena, f_callee_cast, f_callee, fn_ptr_layout);
let result = env.unique_symbol();
let let_result = call_callee(
arena,
result,
f_ret,
f_callee_cast,
f_callee,
f_args,
function_argument_symbols,
);
let_f_callee_cast(
let_f_callee(
//
let_result(
//
@ -209,7 +211,7 @@ pub fn call_erased_function<'a>(
};
let value_is_null = env.unique_symbol();
let let_value_is_null = is_null(env, arena, value_is_null, f_value);
let let_value_is_null = is_null(env, arena, value_is_null, f_value, Layout::OPAQUE_PTR);
let call_and_jump_on_value = let_value_is_null(
//
@ -234,14 +236,10 @@ pub fn call_erased_function<'a>(
ownership: Ownership::Owned,
};
let remainder =
let remainder = let_f_value(
// f_value = ErasedLoad(f, .value)
let_f_value(
// f_callee = ErasedLoad(f, .callee)
let_f_callee(
//
call_and_jump_on_value,
),
// <rest>
call_and_jump_on_value,
);
Stmt::Join {
@ -273,8 +271,7 @@ pub fn call_erased_function<'a>(
/// We generate
///
/// ```
/// boxed_value = Expr::Box({s})
/// stack_value: Ptr<[]> = Cast(boxed_value, Ptr<[]>)
/// value = Expr::Box({s})
/// callee = Expr::FunctionPointer(f)
/// f = Expr::ErasedMake({ value, callee })
/// ```
@ -301,6 +298,8 @@ pub fn build_erased_function<'a>(
let ResolvedErasedLambda {
captures,
lambda_name,
arguments,
ret,
} = resolved_lambda;
let value = match captures {
@ -319,11 +318,16 @@ pub fn build_erased_function<'a>(
hole,
);
let BuiltFunctionPointer {
function_pointer,
reified_arguments: _,
} = build_function_pointer(env.arena, layout_cache, arguments, ret, captures.is_some());
// callee = Expr::FunctionPointer(f)
let result = Stmt::Let(
callee,
Expr::FunctionPointer { lambda_name },
Layout::OPAQUE_PTR,
function_pointer,
env.arena.alloc(result),
);
@ -340,25 +344,11 @@ pub fn build_erased_function<'a>(
let stack_captures_layout =
layout_cache.put_in_direct_no_semantic(LayoutRepr::Struct(layouts));
let boxed_captures = env.unique_symbol();
let boxed_captures_layout =
layout_cache.put_in_direct_no_semantic(LayoutRepr::Boxed(stack_captures_layout));
let result = Stmt::Let(
value.unwrap(),
Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::PtrCast,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: env.arena.alloc([boxed_captures]),
}),
Layout::OPAQUE_PTR,
env.arena.alloc(result),
);
let result = Stmt::Let(
boxed_captures,
Expr::ExprBox {
symbol: stack_captures,
},
@ -386,6 +376,8 @@ struct ResolvedErasedCaptures<'a> {
pub struct ResolvedErasedLambda<'a> {
captures: Option<ResolvedErasedCaptures<'a>>,
lambda_name: LambdaName<'a>,
arguments: &'a [InLayout<'a>],
ret: InLayout<'a>,
}
impl<'a> ResolvedErasedLambda<'a> {
@ -394,6 +386,8 @@ impl<'a> ResolvedErasedLambda<'a> {
layout_cache: &mut LayoutCache<'a>,
lambda_symbol: Symbol,
captures: CapturedSymbols<'a>,
arguments: &'a [InLayout<'a>],
ret: InLayout<'a>,
) -> Self {
let resolved_captures;
let lambda_name;
@ -422,6 +416,8 @@ impl<'a> ResolvedErasedLambda<'a> {
Self {
captures: resolved_captures,
lambda_name,
arguments,
ret,
}
}
@ -454,7 +450,6 @@ pub fn unpack_closure_data<'a>(
captures: &[(Symbol, Variable)],
mut hole: Stmt<'a>,
) -> Stmt<'a> {
let loaded_captures = env.unique_symbol();
let heap_captures = env.unique_symbol();
let stack_captures = env.unique_symbol();
@ -494,24 +489,12 @@ pub fn unpack_closure_data<'a>(
env.arena.alloc(hole),
);
hole = Stmt::Let(
heap_captures,
Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::PtrCast,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: env.arena.alloc([captures_symbol]),
}),
heap_captures_layout,
env.arena.alloc(hole),
);
let let_loaded_captures = index_erased_function(
env.arena,
loaded_captures,
heap_captures,
captures_symbol,
ErasedField::Value,
heap_captures_layout,
);
let_loaded_captures(hole)

View file

@ -722,9 +722,9 @@ impl<'a> FunctionPointer<'a> {
let ret = interner.to_doc(ret, alloc, seen_rec, parens);
alloc
.text("FunPtr(")
.text("FunPtr((")
.append(args)
.append(alloc.text(" -> "))
.append(alloc.text(") -> "))
.append(ret)
.append(")")
}