Merge pull request #666 from rtfeldman/return-function

Return function pointers and closures
This commit is contained in:
Richard Feldman 2020-11-08 22:21:01 -05:00 committed by GitHub
commit 54de538952
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 397 additions and 461 deletions

View file

@ -14,6 +14,8 @@ use roc_types::subs::{Content, FlatType, Subs, Variable};
use std::collections::HashMap;
use ven_pretty::{BoxAllocator, DocAllocator, DocBuilder};
pub const PRETTY_PRINT_IR_SYMBOLS: bool = false;
#[derive(Clone, Debug, PartialEq)]
pub enum MonoProblem {
PatternProblem(crate::exhaustive::Error),
@ -902,8 +904,11 @@ where
D::Doc: Clone,
A: Clone,
{
alloc.text(format!("{}", symbol))
// alloc.text(format!("{:?}", symbol))
if PRETTY_PRINT_IR_SYMBOLS {
alloc.text(format!("{:?}", symbol))
} else {
alloc.text(format!("{}", symbol))
}
}
fn join_point_to_doc<'b, D, A>(alloc: &'b D, symbol: JoinPointId) -> DocBuilder<'b, D, A>
@ -1514,51 +1519,9 @@ fn specialize_external<'a>(
pattern_symbols
};
let (proc_args, opt_closure_layout, ret_layout) =
let specialized =
build_specialized_proc_from_var(env, layout_cache, proc_name, pattern_symbols, fn_var)?;
let mut specialized_body = from_can(env, fn_var, body, procs, layout_cache);
// unpack the closure symbols, if any
if let CapturedSymbols::Captured(captured) = captured_symbols {
let mut layouts = Vec::with_capacity_in(captured.len(), env.arena);
for (_, variable) in captured.iter() {
let layout = layout_cache.from_var(env.arena, *variable, env.subs)?;
layouts.push(layout);
}
let field_layouts = layouts.into_bump_slice();
let wrapped = match &opt_closure_layout {
Some(x) => x.get_wrapped(),
None => unreachable!("symbols are captured, so this must be a closure"),
};
for (index, (symbol, variable)) in captured.iter().enumerate() {
// layout is cached anyway, re-using the one found above leads to
// issues (combining by-ref and by-move in pattern match
let layout = layout_cache.from_var(env.arena, *variable, env.subs)?;
// if the symbol has a layout that is dropped from data structures (e.g. `{}`)
// then regenerate the symbol here. The value may not be present in the closure
// data struct
let expr = {
if layout.is_dropped_because_empty() {
Expr::Struct(&[])
} else {
Expr::AccessAtIndex {
index: index as _,
field_layouts,
structure: Symbol::ARG_CLOSURE,
wrapped,
}
}
};
specialized_body = Stmt::Let(*symbol, expr, layout, env.arena.alloc(specialized_body));
}
}
// determine the layout of aliases/rigids exposed to the host
let host_exposed_layouts = if host_exposed_variables.is_empty() {
HostExposedLayouts::NotHostExposed
@ -1578,39 +1541,141 @@ fn specialize_external<'a>(
}
};
// reset subs, so we don't get type errors when specializing for a different signature
layout_cache.rollback_to(cache_snapshot);
env.subs.rollback_to(snapshot);
let recursivity = if is_self_recursive {
SelfRecursive::SelfRecursive(JoinPointId(env.unique_symbol()))
} else {
SelfRecursive::NotSelfRecursive
};
let closure_data_layout = match opt_closure_layout {
Some(closure_layout) => Some(closure_layout.as_named_layout(proc_name)),
None => None,
};
let mut specialized_body = from_can(env, fn_var, body, procs, layout_cache);
let proc = Proc {
name: proc_name,
args: proc_args,
body: specialized_body,
closure_data_layout,
ret_layout,
is_self_recursive: recursivity,
host_exposed_layouts,
};
match specialized {
SpecializedLayout::FunctionPointerBody {
arguments,
ret_layout,
closure: opt_closure_layout,
} => {
// this is a function body like
//
// foo = Num.add
//
// we need to expand this to
//
// foo = \x,y -> Num.add x y
Ok(proc)
// reset subs, so we don't get type errors when specializing for a different signature
layout_cache.rollback_to(cache_snapshot);
env.subs.rollback_to(snapshot);
let closure_data_layout = match opt_closure_layout {
Some(closure_layout) => Some(closure_layout.as_named_layout(proc_name)),
None => None,
};
// I'm not sure how to handle the closure case, does it ever occur?
debug_assert_eq!(closure_data_layout, None);
debug_assert!(matches!(captured_symbols, CapturedSymbols::None));
// this will be a thunk returning a function, so its ret_layout must be a function!
let full_layout = Layout::FunctionPointer(arguments, env.arena.alloc(ret_layout));
let proc = Proc {
name: proc_name,
args: &[],
body: specialized_body,
closure_data_layout,
ret_layout: full_layout,
is_self_recursive: recursivity,
host_exposed_layouts,
};
Ok(proc)
}
SpecializedLayout::FunctionBody {
arguments: proc_args,
closure: opt_closure_layout,
ret_layout,
} => {
// unpack the closure symbols, if any
if let CapturedSymbols::Captured(captured) = captured_symbols {
let mut layouts = Vec::with_capacity_in(captured.len(), env.arena);
for (_, variable) in captured.iter() {
let layout = layout_cache.from_var(env.arena, *variable, env.subs)?;
layouts.push(layout);
}
let field_layouts = layouts.into_bump_slice();
let wrapped = match &opt_closure_layout {
Some(x) => x.get_wrapped(),
None => unreachable!("symbols are captured, so this must be a closure"),
};
for (index, (symbol, variable)) in captured.iter().enumerate() {
// layout is cached anyway, re-using the one found above leads to
// issues (combining by-ref and by-move in pattern match
let layout = layout_cache.from_var(env.arena, *variable, env.subs)?;
// if the symbol has a layout that is dropped from data structures (e.g. `{}`)
// then regenerate the symbol here. The value may not be present in the closure
// data struct
let expr = {
if layout.is_dropped_because_empty() {
Expr::Struct(&[])
} else {
Expr::AccessAtIndex {
index: index as _,
field_layouts,
structure: Symbol::ARG_CLOSURE,
wrapped,
}
}
};
specialized_body =
Stmt::Let(*symbol, expr, layout, env.arena.alloc(specialized_body));
}
}
// reset subs, so we don't get type errors when specializing for a different signature
layout_cache.rollback_to(cache_snapshot);
env.subs.rollback_to(snapshot);
let closure_data_layout = match opt_closure_layout {
Some(closure_layout) => Some(closure_layout.as_named_layout(proc_name)),
None => None,
};
let proc = Proc {
name: proc_name,
args: proc_args,
body: specialized_body,
closure_data_layout,
ret_layout,
is_self_recursive: recursivity,
host_exposed_layouts,
};
Ok(proc)
}
}
}
type SpecializedLayout<'a> = (
&'a [(Layout<'a>, Symbol)],
Option<ClosureLayout<'a>>,
Layout<'a>,
);
enum SpecializedLayout<'a> {
/// A body like `foo = \a,b,c -> ...`
FunctionBody {
arguments: &'a [(Layout<'a>, Symbol)],
closure: Option<ClosureLayout<'a>>,
ret_layout: Layout<'a>,
},
/// A body like `foo = Num.add`
FunctionPointerBody {
arguments: &'a [Layout<'a>],
closure: Option<ClosureLayout<'a>>,
ret_layout: Layout<'a>,
},
}
#[allow(clippy::type_complexity)]
fn build_specialized_proc_from_var<'a>(
@ -1739,6 +1804,7 @@ fn build_specialized_proc<'a>(
let mut proc_args = Vec::with_capacity_in(pattern_layouts.len(), arena);
let pattern_layouts_len = pattern_layouts.len();
let pattern_layouts_slice = pattern_layouts.clone().into_bump_slice();
for (arg_layout, arg_name) in pattern_layouts.into_iter().zip(pattern_symbols.iter()) {
proc_args.push((arg_layout, *arg_name));
@ -1761,6 +1827,7 @@ fn build_specialized_proc<'a>(
// f_closure = { ptr: f, closure: x }
//
// then
use SpecializedLayout::*;
match opt_closure_layout {
Some(layout) if pattern_symbols.last() == Some(&Symbol::ARG_CLOSURE) => {
// here we define the lifted (now top-level) f function. Its final argument is `Symbol::ARG_CLOSURE`,
@ -1776,7 +1843,11 @@ fn build_specialized_proc<'a>(
let proc_args = proc_args.into_bump_slice();
Ok((proc_args, Some(layout), ret_layout))
Ok(FunctionBody {
arguments: proc_args,
closure: Some(layout),
ret_layout,
})
}
Some(layout) => {
// else if there is a closure layout, we're building the `f_closure` value
@ -1791,7 +1862,11 @@ fn build_specialized_proc<'a>(
let closure_layout =
Layout::Struct(arena.alloc([function_ptr_layout, closure_data_layout]));
Ok((&[], None, closure_layout))
Ok(FunctionBody {
arguments: &[],
closure: None,
ret_layout: closure_layout,
})
}
None => {
// else we're making a normal function, no closure problems to worry about
@ -1805,16 +1880,32 @@ fn build_specialized_proc<'a>(
Ordering::Equal => {
let proc_args = proc_args.into_bump_slice();
Ok((proc_args, None, ret_layout))
Ok(FunctionBody {
arguments: proc_args,
closure: None,
ret_layout,
})
}
Ordering::Greater => {
// so far, the problem when hitting this branch was always somewhere else
// I think this branch should not be reachable in a bugfree compiler
panic!("more arguments (according to the layout) than argument symbols")
}
Ordering::Less => {
panic!("more argument symbols than arguments (according to the layout)")
if pattern_symbols.is_empty() {
Ok(FunctionPointerBody {
arguments: pattern_layouts_slice,
closure: None,
ret_layout,
})
} else {
// so far, the problem when hitting this branch was always somewhere else
// I think this branch should not be reachable in a bugfree compiler
panic!(
"more arguments (according to the layout) than argument symbols for {:?}",
proc_name
)
}
}
Ordering::Less => panic!(
"more argument symbols than arguments (according to the layout) for {:?}",
proc_name
),
}
}
}
@ -2213,8 +2304,22 @@ pub fn with_hole<'a>(
} else if symbol.module_id() != env.home && symbol.module_id() != ModuleId::ATTR {
match layout_cache.from_var(env.arena, variable, env.subs) {
Err(e) => panic!("invalid layout {:?}", e),
Ok(Layout::FunctionPointer(_, _)) => {
Ok(layout @ Layout::FunctionPointer(_, _)) => {
add_needed_external(procs, env, variable, symbol);
match hole {
Stmt::Jump(_, _) => todo!("not sure what to do in this case yet"),
_ => {
let expr = Expr::FunctionPointer(symbol, layout.clone());
let new_symbol = env.unique_symbol();
return Stmt::Let(
new_symbol,
expr,
layout,
env.arena.alloc(Stmt::Ret(new_symbol)),
);
}
}
}
Ok(_) => {
// this is a 0-arity thunk
@ -4678,6 +4783,12 @@ fn call_by_name<'a>(
.specialized
.contains_key(&(proc_name, full_layout.clone()))
{
debug_assert_eq!(
arg_layouts.len(),
field_symbols.len(),
"see call_by_name for background (scroll down a bit)"
);
let call = Expr::FunctionCall {
call_type: CallType::ByName(proc_name),
ret_layout: ret_layout.clone(),
@ -4721,6 +4832,11 @@ fn call_by_name<'a>(
);
}
debug_assert_eq!(
arg_layouts.len(),
field_symbols.len(),
"see call_by_name for background (scroll down a bit)"
);
let call = Expr::FunctionCall {
call_type: CallType::ByName(proc_name),
ret_layout: ret_layout.clone(),
@ -4762,30 +4878,102 @@ fn call_by_name<'a>(
let function_layout =
FunctionLayouts::from_layout(env.arena, layout);
procs.specialized.remove(&(proc_name, full_layout));
procs.specialized.remove(&(proc_name, full_layout.clone()));
procs.specialized.insert(
(proc_name, function_layout.full.clone()),
Done(proc),
);
let call = Expr::FunctionCall {
call_type: CallType::ByName(proc_name),
ret_layout: function_layout.result.clone(),
full_layout: function_layout.full,
arg_layouts: function_layout.arguments,
args: field_symbols,
};
if field_symbols.is_empty() {
debug_assert!(loc_args.is_empty());
let iter = loc_args
.into_iter()
.rev()
.zip(field_symbols.iter().rev());
// This happens when we return a function, e.g.
//
// foo = Num.add
//
// Even though the layout (and type) are functions,
// there are no arguments. This confuses our IR,
// and we have to fix it here.
match full_layout {
Layout::Closure(_, closure_layout, _) => {
let call = Expr::FunctionCall {
call_type: CallType::ByName(proc_name),
ret_layout: function_layout.result.clone(),
full_layout: function_layout.full.clone(),
arg_layouts: function_layout.arguments,
args: field_symbols,
};
let result =
Stmt::Let(assigned, call, function_layout.result, hole);
// in the case of a closure specifically, we
// have to create a custom layout, to make sure
// the closure data is part of the layout
let closure_struct_layout = Layout::Struct(
env.arena.alloc([
function_layout.full,
closure_layout
.as_block_of_memory_layout(),
]),
);
assign_to_symbols(env, procs, layout_cache, iter, result)
Stmt::Let(
assigned,
call,
closure_struct_layout,
hole,
)
}
_ => {
let call = Expr::FunctionCall {
call_type: CallType::ByName(proc_name),
ret_layout: function_layout.result.clone(),
full_layout: function_layout.full.clone(),
arg_layouts: function_layout.arguments,
args: field_symbols,
};
Stmt::Let(
assigned,
call,
function_layout.full,
hole,
)
}
}
} else {
debug_assert_eq!(
function_layout.arguments.len(),
field_symbols.len(),
"scroll up a bit for background"
);
let call = Expr::FunctionCall {
call_type: CallType::ByName(proc_name),
ret_layout: function_layout.result.clone(),
full_layout: function_layout.full,
arg_layouts: function_layout.arguments,
args: field_symbols,
};
let iter = loc_args
.into_iter()
.rev()
.zip(field_symbols.iter().rev());
let result = Stmt::Let(
assigned,
call,
function_layout.result,
hole,
);
assign_to_symbols(
env,
procs,
layout_cache,
iter,
result,
)
}
}
Err(error) => {
let error_msg = env.arena.alloc(format!(
@ -4804,6 +4992,12 @@ fn call_by_name<'a>(
None if assigned.module_id() != proc_name.module_id() => {
add_needed_external(procs, env, original_fn_var, proc_name);
debug_assert_eq!(
arg_layouts.len(),
field_symbols.len(),
"scroll up a bit for background"
);
let call = Expr::FunctionCall {
call_type: CallType::ByName(proc_name),
ret_layout: ret_layout.clone(),