Merge pull request #2985 from rtfeldman/closure-called-in-defining-scope

Closure called in defining scope
This commit is contained in:
Ayaz 2022-05-02 16:41:55 -04:00 committed by GitHub
commit 5ec815a373
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 125 additions and 22 deletions

View file

@ -111,6 +111,11 @@ impl Symbol {
pub const fn to_ne_bytes(self) -> [u8; 8] { pub const fn to_ne_bytes(self) -> [u8; 8] {
self.0.to_ne_bytes() self.0.to_ne_bytes()
} }
#[cfg(debug_assertions)]
pub fn contains(self, needle: &str) -> bool {
format!("{:?}", self).contains(needle)
}
} }
/// Rather than displaying as this: /// Rather than displaying as this:

View file

@ -4172,7 +4172,7 @@ pub fn with_hole<'a>(
layout_cache, layout_cache,
lambda_set, lambda_set,
name, name,
symbols, symbols.iter().copied(),
assigned, assigned,
hole, hole,
) )
@ -4744,17 +4744,25 @@ fn get_specialization<'a>(
} }
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn construct_closure_data<'a>( fn construct_closure_data<'a, I>(
env: &mut Env<'a, '_>, env: &mut Env<'a, '_>,
procs: &mut Procs<'a>, procs: &mut Procs<'a>,
layout_cache: &mut LayoutCache<'a>, layout_cache: &mut LayoutCache<'a>,
lambda_set: LambdaSet<'a>, lambda_set: LambdaSet<'a>,
name: Symbol, name: Symbol,
symbols: &'a [&(Symbol, Variable)], symbols: I,
assigned: Symbol, assigned: Symbol,
hole: &'a Stmt<'a>, hole: &'a Stmt<'a>,
) -> Stmt<'a> { ) -> Stmt<'a>
where
I: IntoIterator<Item = &'a (Symbol, Variable)>,
I::IntoIter: ExactSizeIterator,
{
let lambda_set_layout = Layout::LambdaSet(lambda_set); let lambda_set_layout = Layout::LambdaSet(lambda_set);
let symbols = symbols.into_iter();
// arguments with a polymorphic type that we have to deal with
let mut polymorphic_arguments = Vec::new_in(env.arena);
let mut result = match lambda_set.layout_for_member(name) { let mut result = match lambda_set.layout_for_member(name) {
ClosureRepresentation::Union { ClosureRepresentation::Union {
@ -4765,10 +4773,14 @@ fn construct_closure_data<'a>(
} => { } => {
// captured variables are in symbol-alphabetic order, but now we want // captured variables are in symbol-alphabetic order, but now we want
// them ordered by their alignment requirements // them ordered by their alignment requirements
let mut combined = Vec::from_iter_in( let mut combined = Vec::with_capacity_in(symbols.len(), env.arena);
symbols.iter().map(|&&(s, _)| s).zip(field_layouts.iter()), for ((symbol, variable), layout) in symbols.zip(field_layouts.iter()) {
env.arena, if procs.partial_exprs.contains(*symbol) {
); polymorphic_arguments.push((*symbol, *variable));
}
combined.push((*symbol, layout))
}
let ptr_bytes = env.target_info; let ptr_bytes = env.target_info;
@ -4796,10 +4808,14 @@ fn construct_closure_data<'a>(
// captured variables are in symbol-alphabetic order, but now we want // captured variables are in symbol-alphabetic order, but now we want
// them ordered by their alignment requirements // them ordered by their alignment requirements
let mut combined = Vec::from_iter_in( let mut combined = Vec::with_capacity_in(symbols.len(), env.arena);
symbols.iter().map(|&(s, _)| s).zip(field_layouts.iter()), for ((symbol, variable), layout) in symbols.zip(field_layouts.iter()) {
env.arena, if procs.partial_exprs.contains(*symbol) {
); polymorphic_arguments.push((*symbol, *variable));
}
combined.push((*symbol, layout))
}
let ptr_bytes = env.target_info; let ptr_bytes = env.target_info;
@ -4811,7 +4827,7 @@ fn construct_closure_data<'a>(
}); });
let symbols = let symbols =
Vec::from_iter_in(combined.iter().map(|(a, _)| **a), env.arena).into_bump_slice(); Vec::from_iter_in(combined.iter().map(|(a, _)| *a), env.arena).into_bump_slice();
let field_layouts = let field_layouts =
Vec::from_iter_in(combined.iter().map(|(_, b)| **b), env.arena).into_bump_slice(); Vec::from_iter_in(combined.iter().map(|(_, b)| **b), env.arena).into_bump_slice();
@ -4852,10 +4868,8 @@ fn construct_closure_data<'a>(
// TODO: this is not quite right. What we should actually be doing is removing references to // TODO: this is not quite right. What we should actually be doing is removing references to
// polymorphic expressions from the captured symbols, and allowing the specializations of those // polymorphic expressions from the captured symbols, and allowing the specializations of those
// symbols to be inlined when specializing the closure body elsewhere. // symbols to be inlined when specializing the closure body elsewhere.
for &&(symbol, var) in symbols { for (symbol, var) in polymorphic_arguments {
if procs.partial_exprs.contains(symbol) { result = specialize_symbol(env, procs, layout_cache, Some(var), symbol, result, symbol);
result = specialize_symbol(env, procs, layout_cache, Some(var), symbol, result, symbol);
}
} }
result result
@ -6947,7 +6961,7 @@ fn specialize_symbol<'a>(
layout_cache, layout_cache,
lambda_set, lambda_set,
original, original,
symbols, symbols.iter().copied(),
closure_data, closure_data,
env.arena.alloc(result), env.arena.alloc(result),
) )
@ -7407,13 +7421,20 @@ fn call_by_name_help<'a>(
proc_name, proc_name,
); );
let has_closure = argument_layouts.len() != top_level_layout.arguments.len();
let closure_argument = env.unique_symbol();
if has_closure {
field_symbols.push(closure_argument);
}
let field_symbols = field_symbols.into_bump_slice(); let field_symbols = field_symbols.into_bump_slice();
let call = self::Call { let call = self::Call {
call_type: CallType::ByName { call_type: CallType::ByName {
name: proc_name, name: proc_name,
ret_layout, ret_layout,
arg_layouts: argument_layouts, arg_layouts: top_level_layout.arguments,
specialization_id: env.next_call_specialization_id(), specialization_id: env.next_call_specialization_id(),
}, },
arguments: field_symbols, arguments: field_symbols,
@ -7421,8 +7442,33 @@ fn call_by_name_help<'a>(
let result = build_call(env, call, assigned, *ret_layout, hole); let result = build_call(env, call, assigned, *ret_layout, hole);
let iter = loc_args.into_iter().rev().zip(field_symbols.iter().rev()); // NOTE: the zip omits the closure symbol, if it exists,
assign_to_symbols(env, procs, layout_cache, iter, result) // because loc_args then is shorter than field_symbols
debug_assert!([0, 1].contains(&(field_symbols.len() - loc_args.len())));
let iter = loc_args.into_iter().zip(field_symbols.iter()).rev();
let result = assign_to_symbols(env, procs, layout_cache, iter, result);
if has_closure {
let partial_proc = procs.partial_procs.get_symbol(proc_name).unwrap();
let captured = match partial_proc.captured_symbols {
CapturedSymbols::None => &[],
CapturedSymbols::Captured(slice) => slice,
};
construct_closure_data(
env,
procs,
layout_cache,
lambda_set,
proc_name,
captured.iter(),
closure_argument,
env.arena.alloc(result),
)
} else {
result
}
} }
PendingSpecializations::Making => { PendingSpecializations::Making => {
let opt_partial_proc = procs.partial_procs.symbol_to_id(proc_name); let opt_partial_proc = procs.partial_procs.symbol_to_id(proc_name);
@ -7706,7 +7752,7 @@ fn call_specialized_proc<'a>(
layout_cache, layout_cache,
lambda_set, lambda_set,
proc_name, proc_name,
symbols, symbols.iter().copied(),
closure_data_symbol, closure_data_symbol,
env.arena.alloc(new_hole), env.arena.alloc(new_hole),
); );

View file

@ -3344,3 +3344,55 @@ fn box_and_unbox_tag_union() {
(u8, u8) (u8, u8)
) )
} }
#[test]
#[cfg(any(feature = "gen-llvm"))]
fn closure_called_in_its_defining_scope() {
assert_evals_to!(
indoc!(
r#"
app "test" provides [ main ] to "./platform"
main : Str
main =
g : Str
g = "hello world"
getG : {} -> Str
getG = \{} -> g
getG {}
"#
),
RocStr::from("hello world"),
RocStr
)
}
#[test]
#[ignore]
#[cfg(any(feature = "gen-llvm"))]
fn issue_2894() {
assert_evals_to!(
indoc!(
r#"
app "test" provides [ main ] to "./platform"
main : U32
main =
g : { x : U32 }
g = { x: 1u32 }
getG : {} -> { x : U32 }
getG = \{} -> g
h : {} -> U32
h = \{} -> (getG {}).x
h {}
"#
),
1u32,
u32
)
}