This commit is contained in:
Folkert 2020-10-16 00:18:40 +02:00
parent d0f031fe6c
commit 40ffca2b7b
16 changed files with 511 additions and 200 deletions

View file

@ -23,10 +23,17 @@ pub enum MonoProblem {
pub struct PartialProc<'a> {
pub annotation: Variable,
pub pattern_symbols: &'a [Symbol],
pub captured_symbols: CapturedSymbols<'a>,
pub body: roc_can::expr::Expr,
pub is_self_recursive: bool,
}
#[derive(Clone, Debug, PartialEq)]
pub enum CapturedSymbols<'a> {
None,
Captured(&'a [(Symbol, Variable)]),
}
#[derive(Clone, Debug, PartialEq)]
pub struct PendingSpecialization {
solved_type: SolvedType,
@ -292,7 +299,7 @@ impl<'a> Procs<'a> {
let borrow_params = arena.alloc(crate::borrow::infer_borrow(arena, &result));
for (_, proc) in result.iter_mut() {
crate::inc_dec::visit_proc(arena, borrow_params, proc);
// crate::inc_dec::visit_proc(arena, borrow_params, proc);
}
(result, borrow_params)
@ -308,6 +315,7 @@ impl<'a> Procs<'a> {
annotation: Variable,
loc_args: std::vec::Vec<(Variable, Located<roc_can::pattern::Pattern>)>,
loc_body: Located<roc_can::expr::Expr>,
captured_symbols: CapturedSymbols<'a>,
is_self_recursive: bool,
ret_var: Variable,
) {
@ -323,6 +331,7 @@ impl<'a> Procs<'a> {
PartialProc {
annotation,
pattern_symbols,
captured_symbols,
body: body.value,
is_self_recursive,
},
@ -354,6 +363,7 @@ impl<'a> Procs<'a> {
annotation: Variable,
loc_args: std::vec::Vec<(Variable, Located<roc_can::pattern::Pattern>)>,
loc_body: Located<roc_can::expr::Expr>,
captured_symbols: CapturedSymbols<'a>,
ret_var: Variable,
layout_cache: &mut LayoutCache<'a>,
) -> Result<Layout<'a>, RuntimeError> {
@ -395,6 +405,7 @@ impl<'a> Procs<'a> {
PartialProc {
annotation,
pattern_symbols,
captured_symbols,
body: body.value,
is_self_recursive,
},
@ -405,6 +416,7 @@ impl<'a> Procs<'a> {
let partial_proc = PartialProc {
annotation,
pattern_symbols,
captured_symbols,
body: body.value,
is_self_recursive,
};
@ -1357,6 +1369,7 @@ fn specialize_external<'a>(
let PartialProc {
annotation,
pattern_symbols,
captured_symbols,
body,
is_self_recursive,
} = partial_proc;
@ -1370,7 +1383,42 @@ fn specialize_external<'a>(
let is_valid = matches!(unified, roc_unify::unify::Unified::Success(_));
debug_assert!(is_valid);
let specialized_body = from_can(env, body, procs, layout_cache);
let mut specialized_body = from_can(env, body, procs, layout_cache);
// if this is a closure, add the closure record argument
let pattern_symbols = if let CapturedSymbols::Captured(_) = captured_symbols {
let mut temp = Vec::from_iter_in(pattern_symbols.iter().copied(), env.arena);
temp.push(Symbol::ARG_CLOSURE);
temp.into_bump_slice()
} else {
pattern_symbols
};
// unpack the closure symbols, if any
if let CapturedSymbols::Captured(captured) = captured_symbols {
let mut layouts = Vec::with_capacity_in(captured.len(), env.arena);
for (_, variable) in captured.iter() {
let layout = layout_cache.from_var(env.arena, *variable, env.subs)?;
layouts.push(layout);
}
let field_layouts = layouts.into_bump_slice();
for (index, (symbol, variable)) in captured.iter().enumerate() {
let expr = Expr::AccessAtIndex {
index: index as _,
field_layouts,
structure: Symbol::ARG_CLOSURE,
wrapped: Wrapped::RecordOrSingleTagUnion,
};
// layout is cached anyway, re-using the one found above leads to
// issues (combining by-ref and by-move in pattern match
let layout = layout_cache.from_var(env.arena, *variable, env.subs)?;
specialized_body = Stmt::Let(*symbol, expr, layout, env.arena.alloc(specialized_body));
}
}
let (proc_args, closes_over, ret_layout) =
build_specialized_proc_from_var(env, layout_cache, pattern_symbols, fn_var)?;
@ -1441,18 +1489,31 @@ fn build_specialized_proc<'a>(
) -> Result<(&'a [(Layout<'a>, Symbol)], Layout<'a>, Layout<'a>), LayoutProblem> {
let mut proc_args = Vec::with_capacity_in(pattern_vars.len(), &env.arena);
debug_assert_eq!(
&pattern_vars.len(),
&pattern_symbols.len(),
"Tried to zip two vecs with different lengths!"
);
for (arg_var, arg_name) in pattern_vars.iter().zip(pattern_symbols.iter()) {
let layout = layout_cache.from_var(&env.arena, *arg_var, env.subs)?;
proc_args.push((layout, *arg_name));
}
// is the final argument symbol the closure symbol? then add the closure variable to the
// pattern variables
if pattern_symbols.last() == Some(&Symbol::ARG_CLOSURE) {
let layout = layout_cache.from_var(&env.arena, closure_var.unwrap(), env.subs)?;
proc_args.push((layout, Symbol::ARG_CLOSURE));
debug_assert_eq!(
pattern_vars.len() + 1,
pattern_symbols.len(),
"Tried to zip two vecs with different lengths!"
);
} else {
debug_assert_eq!(
pattern_vars.len(),
pattern_symbols.len(),
"Tried to zip two vecs with different lengths!"
);
}
let proc_args = proc_args.into_bump_slice();
let closes_over = match closure_var {
@ -1638,6 +1699,7 @@ pub fn with_hole<'a>(
function_type,
arguments,
loc_body,
CapturedSymbols::None,
is_self_recursive,
return_type,
);
@ -1749,6 +1811,7 @@ pub fn with_hole<'a>(
function_type,
arguments,
loc_body,
CapturedSymbols::None,
is_self_recursive,
return_type,
);
@ -2293,6 +2356,7 @@ pub fn with_hole<'a>(
function_var,
arguments,
loc_body,
CapturedSymbols::None,
field_var,
layout_cache,
) {
@ -2433,6 +2497,7 @@ pub fn with_hole<'a>(
function_type,
arguments,
loc_body,
CapturedSymbols::None,
return_type,
layout_cache,
) {
@ -2509,6 +2574,7 @@ pub fn with_hole<'a>(
let arg_layouts = match full_layout {
Layout::FunctionPointer(args, _) => args,
Layout::Closure(args, _, _) => args,
_ => unreachable!("function has layout that is not function pointer"),
};
@ -2523,18 +2589,96 @@ pub fn with_hole<'a>(
let mut result;
match can_reuse_symbol(procs, &loc_expr.value) {
Some(function_symbol) => {
result = Stmt::Let(
assigned,
Expr::FunctionCall {
call_type: CallType::ByPointer(function_symbol),
full_layout,
ret_layout: ret_layout.clone(),
args: arg_symbols,
if let Layout::Closure(_, closure_fields, _) = full_layout {
// we're invoking a closure
let closure_record_symbol = env.unique_symbol();
let closure_function_symbol = env.unique_symbol();
let closure_symbol = function_symbol;
// layout of the closure record
let closure_record_layout = Layout::Struct(closure_fields);
let arg_symbols = {
let mut temp =
Vec::from_iter_in(arg_symbols.iter().copied(), env.arena);
temp.push(closure_record_symbol);
temp.into_bump_slice()
};
let arg_layouts = {
let mut temp =
Vec::from_iter_in(arg_layouts.iter().cloned(), env.arena);
temp.push(closure_record_layout.clone());
temp.into_bump_slice()
};
// layout of the function itself, so typically FunctionPointer(arg_layouts ++ [closure_record], ret_layout)
let function_ptr_layout = Layout::FunctionPointer(
arg_layouts,
},
ret_layout,
arena.alloc(hole),
);
env.arena.alloc(ret_layout.clone()),
);
// build the call
result = Stmt::Let(
assigned,
Expr::FunctionCall {
call_type: CallType::ByPointer(closure_function_symbol),
full_layout: function_ptr_layout.clone(),
ret_layout: ret_layout.clone(),
args: arg_symbols,
arg_layouts,
},
ret_layout,
arena.alloc(hole),
);
// layout of the ( function_pointer, closure_record ) pair
let closure_layout = env.arena.alloc([
function_ptr_layout.clone(),
closure_record_layout.clone(),
]);
// extract & assign the closure function
let expr = Expr::AccessAtIndex {
index: 0,
field_layouts: closure_layout,
structure: closure_symbol,
wrapped: Wrapped::RecordOrSingleTagUnion,
};
result = Stmt::Let(
closure_function_symbol,
expr,
function_ptr_layout,
env.arena.alloc(result),
);
// extract & assign the closure record
let expr = Expr::AccessAtIndex {
index: 1,
field_layouts: closure_layout,
structure: closure_symbol,
wrapped: Wrapped::RecordOrSingleTagUnion,
};
result = Stmt::Let(
closure_record_symbol,
expr,
closure_record_layout,
env.arena.alloc(result),
);
} else {
result = Stmt::Let(
assigned,
Expr::FunctionCall {
call_type: CallType::ByPointer(function_symbol),
full_layout,
ret_layout: ret_layout.clone(),
args: arg_symbols,
arg_layouts,
},
ret_layout,
arena.alloc(hole),
);
}
}
None => {
let function_symbol = env.unique_symbol();
@ -2709,6 +2853,7 @@ pub fn from_can<'a>(
function_type,
arguments,
loc_body,
CapturedSymbols::None,
is_self_recursive,
return_type,
);
@ -2731,10 +2876,12 @@ pub fn from_can<'a>(
match def.loc_expr.value {
Closure {
function_type,
closure_type,
return_type,
recursive,
arguments,
loc_body: boxed_body,
captured_symbols,
..
} => {
// Extract Procs, but discard the resulting Expr::Load.
@ -2745,26 +2892,51 @@ pub fn from_can<'a>(
let is_self_recursive =
!matches!(recursive, roc_can::expr::Recursive::NotRecursive);
procs.insert_named(
env,
layout_cache,
*symbol,
function_type,
arguments,
loc_body,
is_self_recursive,
return_type,
);
// does this function capture any local values?
let function_layout =
layout_cache.from_var(env.arena, function_type, env.subs);
let is_closure =
matches!(&function_layout, Ok(Layout::Closure(_, _, _)));
if is_closure {
let function_layout = function_layout.unwrap();
let full_layout = function_layout.clone();
if let Ok(Layout::Closure(
argument_layouts,
closure_fields,
ret_layout,
)) = &function_layout
{
let mut captured_symbols =
Vec::from_iter_in(captured_symbols, env.arena);
captured_symbols.sort();
let captured_symbols = captured_symbols.into_bump_slice();
procs.insert_named(
env,
layout_cache,
*symbol,
function_type,
arguments,
loc_body,
CapturedSymbols::Captured(captured_symbols),
is_self_recursive,
return_type,
);
let closure_data_layout = Layout::Struct(closure_fields);
// define the function pointer
let function_ptr_layout = {
let mut temp = Vec::from_iter_in(
argument_layouts.iter().cloned(),
env.arena,
);
temp.push(closure_data_layout.clone());
Layout::FunctionPointer(
temp.into_bump_slice(),
ret_layout.clone(),
)
};
let full_layout = function_ptr_layout.clone();
let fn_var = function_type;
let proc_name = *symbol;
let pending = PendingSpecialization::from_var(env.subs, fn_var);
@ -2816,34 +2988,26 @@ pub fn from_can<'a>(
partial_proc,
) {
Ok((proc, layout)) => {
debug_assert_eq!(full_layout, layout);
// debug_assert_eq!(full_layout, layout);
let function_layout =
FunctionLayouts::from_layout(layout);
procs
.specialized
.remove(&(proc_name, full_layout));
procs.specialized.remove(&(
proc_name,
full_layout.clone(),
));
procs.specialized.insert(
(
proc_name,
function_layout.full.clone(),
// function_layout.full.clone(),
full_layout.clone(),
),
Done(proc),
);
}
Err(error) => {
let error_msg = env.arena.alloc(format!(
"TODO generate a RuntimeError message for {:?}",
error
));
procs
.runtime_errors
.insert(proc_name, error_msg);
panic!();
// Stmt::RuntimeError(error_msg)
}
}
}
@ -2863,33 +3027,54 @@ pub fn from_can<'a>(
stmt = Stmt::Let(
*symbol,
expr,
function_layout.clone(),
Layout::Struct(env.arena.alloc([
function_ptr_layout.clone(),
closure_data_layout.clone(),
])),
env.arena.alloc(stmt),
);
// define the closure data
let expr = Expr::Struct(&[]);
let closure_data_layout = Layout::Struct(&[]);
let symbols = Vec::from_iter_in(
captured_symbols.iter().map(|x| x.0),
env.arena,
)
.into_bump_slice();
let expr = Expr::Struct(symbols);
stmt = Stmt::Let(
closure_data,
expr,
closure_data_layout,
closure_data_layout.clone(),
env.arena.alloc(stmt),
);
dbg!(&stmt);
// define the function pointer
let expr = Expr::FunctionPointer(*symbol, function_layout.clone());
let expr =
Expr::FunctionPointer(*symbol, function_ptr_layout.clone());
stmt = Stmt::Let(
function_pointer,
expr,
function_layout,
function_ptr_layout,
env.arena.alloc(stmt),
);
return stmt;
} else {
procs.insert_named(
env,
layout_cache,
*symbol,
function_type,
arguments,
loc_body,
CapturedSymbols::None,
is_self_recursive,
return_type,
);
return from_can(env, cont.value, procs, layout_cache);
}
}