Support monomorphic captures of polymorphic expressions in closures

Closes #4349
This commit is contained in:
Ayaz Hafiz 2022-10-17 13:52:33 -05:00
parent a4ed5a582d
commit ee8e718cc1
No known key found for this signature in database
GPG key ID: 0E2A37416A25EF58
2 changed files with 212 additions and 98 deletions

View file

@ -847,6 +847,11 @@ struct SpecializationMark<'a> {
function_mark: Option<RawFunctionLayout<'a>>,
}
/// The deepest closure a symbol specialization was used in.
#[derive(Debug, Clone, Copy)]
#[repr(transparent)]
struct DeepestUse(Symbol);
/// When walking a function body, we may encounter specialized usages of polymorphic symbols. For
/// example
///
@ -865,73 +870,10 @@ struct SymbolSpecializations<'a>(
// 2. the number of specializations of a symbol in a def is even smaller (almost always only one)
// So, a linear VecMap is preferrable. Use a two-layered one to make (1) extraction of defs easy
// and (2) reads of a certain symbol be determined by its first occurrence, not its last.
VecMap<Symbol, VecMap<SpecializationMark<'a>, (Variable, Symbol)>>,
VecMap<Symbol, VecMap<SpecializationMark<'a>, (Variable, Symbol, DeepestUse)>>,
);
impl<'a> SymbolSpecializations<'a> {
/// Gets a specialization for a symbol, or creates a new one.
#[inline(always)]
fn get_or_insert(
&mut self,
env: &mut Env<'a, '_>,
layout_cache: &mut LayoutCache<'a>,
symbol: Symbol,
specialization_var: Variable,
) -> Symbol {
let arena = env.arena;
let subs: &Subs = env.subs;
let layout = match layout_cache.from_var(arena, specialization_var, subs) {
Ok(layout) => layout,
// This can happen when the def symbol has a type error. In such cases just use the
// def symbol, which is erroring.
Err(_) => return symbol,
};
let is_closure = matches!(
subs.get_content_without_compacting(specialization_var),
Content::Structure(FlatType::Func(..))
);
let function_mark = if is_closure {
let fn_layout = match layout_cache.raw_from_var(arena, specialization_var, subs) {
Ok(layout) => layout,
// This can happen when the def symbol has a type error. In such cases just use the
// def symbol, which is erroring.
Err(_) => return symbol,
};
Some(fn_layout)
} else {
None
};
let specialization_mark = SpecializationMark {
layout,
function_mark,
};
let symbol_specializations = self.0.get_or_insert(symbol, Default::default);
// For the first specialization, always reuse the current symbol. The vast majority of defs
// only have one instance type, so this preserves readability of the IR.
// TODO: turn me off and see what breaks.
let needs_fresh_symbol = !symbol_specializations.is_empty();
let mut make_specialized_symbol = || {
if needs_fresh_symbol {
env.unique_symbol()
} else {
symbol
}
};
let (_var, specialized_symbol) = symbol_specializations
.get_or_insert(specialization_mark, || {
(specialization_var, make_specialized_symbol())
});
*specialized_symbol
}
/// Inserts a known specialization for a symbol. Returns the overwritten specialization, if any.
pub fn get_or_insert_known(
&mut self,
@ -939,17 +881,20 @@ impl<'a> SymbolSpecializations<'a> {
mark: SpecializationMark<'a>,
specialization_var: Variable,
specialization_symbol: Symbol,
) -> Option<(Variable, Symbol)> {
self.0
.get_or_insert(symbol, Default::default)
.insert(mark, (specialization_var, specialization_symbol))
deepest_use: DeepestUse,
) -> Option<(Variable, Symbol, DeepestUse)> {
self.0.get_or_insert(symbol, Default::default).insert(
mark,
(specialization_var, specialization_symbol, deepest_use),
)
}
/// Removes all specializations for a symbol, returning the type and symbol of each specialization.
pub fn remove(
&mut self,
symbol: Symbol,
) -> impl ExactSizeIterator<Item = (SpecializationMark<'a>, (Variable, Symbol))> {
) -> impl ExactSizeIterator<Item = (SpecializationMark<'a>, (Variable, Symbol, DeepestUse))>
{
self.0
.remove(&symbol)
.map(|(_, specializations)| specializations)
@ -969,7 +914,7 @@ impl<'a> SymbolSpecializations<'a> {
symbol
);
specializations.next().map(|(_, (_, symbol))| symbol)
specializations.next().map(|(_, (_, symbol, _))| symbol)
}
pub fn is_empty(&self) -> bool {
@ -987,6 +932,30 @@ pub struct ProcsBase<'a> {
pub imported_module_thunks: &'a [Symbol],
}
/// The current set of functions under specialization. They form a stack where the latest
/// specialization to be seen is at the head of the stack.
#[derive(Clone, Debug)]
struct SpecializationStack<'a>(Vec<'a, Symbol>);
impl<'a> SpecializationStack<'a> {
fn current_deepest_lambda(&self) -> DeepestUse {
DeepestUse(*self.0.last().unwrap())
}
fn is_nested_closure(&self, inner: Symbol, outer: Symbol) -> Option<bool> {
let mut seen_outer = false;
for &fun in self.0.iter() {
if fun == outer {
seen_outer = true;
}
if fun == inner {
return Some(seen_outer);
}
}
return None;
}
}
#[derive(Clone, Debug)]
pub struct Procs<'a> {
pub partial_procs: PartialProcs<'a>,
@ -998,8 +967,7 @@ pub struct Procs<'a> {
pub runtime_errors: BumpMap<Symbol, &'a str>,
pub externals_we_need: BumpMap<ModuleId, ExternalSpecializations<'a>>,
symbol_specializations: SymbolSpecializations<'a>,
/// The current set of functions under specialization.
pub specialization_stack: Vec<'a, Symbol>,
specialization_stack: SpecializationStack<'a>,
}
impl<'a> Procs<'a> {
@ -1014,17 +982,18 @@ impl<'a> Procs<'a> {
runtime_errors: BumpMap::new_in(arena),
externals_we_need: BumpMap::new_in(arena),
symbol_specializations: Default::default(),
specialization_stack: Vec::with_capacity_in(16, arena),
specialization_stack: SpecializationStack(Vec::with_capacity_in(16, arena)),
}
}
fn push_active_specialization(&mut self, specialization: Symbol) {
self.specialization_stack.push(specialization);
self.specialization_stack.0.push(specialization);
}
fn pop_active_specialization(&mut self, specialization: Symbol) {
let popped = self
.specialization_stack
.0
.pop()
.expect("specialization stack is empty");
debug_assert_eq!(
@ -1049,7 +1018,7 @@ impl<'a> Procs<'a> {
/// specialize both `foo : Str False -> Str` and `foo : {} False -> Str` at the same time, so
/// the latter specialization must be deferred.
fn symbol_needs_suspended_specialization(&self, specialization: Symbol) -> bool {
self.specialization_stack.contains(&specialization)
self.specialization_stack.0.contains(&specialization)
}
}
@ -1368,6 +1337,100 @@ impl<'a> Procs<'a> {
}
}
}
/// Gets a specialization for a symbol, or creates a new one.
#[inline(always)]
fn get_or_insert_symbol_specialization(
&mut self,
env: &mut Env<'a, '_>,
layout_cache: &mut LayoutCache<'a>,
symbol: Symbol,
specialization_var: Variable,
) -> Symbol {
let arena = env.arena;
let subs: &Subs = env.subs;
let layout = match layout_cache.from_var(arena, specialization_var, subs) {
Ok(layout) => layout,
// This can happen when the def symbol has a type error. In such cases just use the
// def symbol, which is erroring.
Err(_) => return symbol,
};
let is_closure = matches!(
subs.get_content_without_compacting(specialization_var),
Content::Structure(FlatType::Func(..))
);
let function_mark = if is_closure {
let fn_layout = match layout_cache.raw_from_var(arena, specialization_var, subs) {
Ok(layout) => layout,
// This can happen when the def symbol has a type error. In such cases just use the
// def symbol, which is erroring.
Err(_) => return symbol,
};
Some(fn_layout)
} else {
None
};
let specialization_mark = SpecializationMark {
layout,
function_mark,
};
let symbol_specializations = self
.symbol_specializations
.0
.get_or_insert(symbol, Default::default);
// For the first specialization, always reuse the current symbol. The vast majority of defs
// only have one instance type, so this preserves readability of the IR.
// TODO: turn me off and see what breaks.
let needs_fresh_symbol = !symbol_specializations.is_empty();
let mut make_specialized_symbol = || {
if needs_fresh_symbol {
env.unique_symbol()
} else {
symbol
}
};
let current_use = self.specialization_stack.current_deepest_lambda();
let (_var, specialized_symbol, deepest_use) = symbol_specializations
.get_or_insert(specialization_mark, || {
(specialization_var, make_specialized_symbol(), current_use)
});
if matches!(
self.specialization_stack
.is_nested_closure(current_use.0, deepest_use.0),
Some(true) | None
) {
*deepest_use = current_use;
}
*specialized_symbol
}
pub fn get_symbol_specializations_used_in(
&self,
lambda: LambdaName<'_>,
symbol: Symbol,
) -> Option<impl Iterator<Item = (Variable, Symbol)> + '_> {
let lambda_name = lambda.name();
self.symbol_specializations.0.get(&symbol).map(move |l| {
l.iter().filter_map(move |(_, (var, sym, deepest_use))| {
match self
.specialization_stack
.is_nested_closure(deepest_use.0, lambda_name)
{
None | Some(true) => Some((*var, *sym)),
Some(false) => None,
}
})
})
}
}
#[derive(Default)]
@ -2503,7 +2566,7 @@ fn from_can_let<'a>(
// We do need specializations
1 => {
let (_specialization_mark, (var, specialized_symbol)) =
let (_specialization_mark, (var, specialized_symbol, _deepest_use)) =
needed_specializations.next().unwrap();
// Make sure rigid variables in the annotation are converted to flex variables.
@ -2534,7 +2597,7 @@ fn from_can_let<'a>(
// Need to eat the cost and create a specialized version of the body for
// each specialization.
for (_specialization_mark, (var, specialized_symbol)) in
for (_specialization_mark, (var, specialized_symbol, _deepest_use)) in
needed_specializations
{
use roc_can::copy::deep_copy_type_vars_into_expr;
@ -3388,11 +3451,23 @@ fn specialize_proc_help<'a>(
// An argument from the closure list may have taken on a specialized symbol
// name during the evaluation of the def body. If this is the case, load the
// specialized name rather than the original captured name!
let mut get_specialized_name = |symbol| {
procs
.symbol_specializations
.remove_single(symbol)
.unwrap_or(symbol)
let get_specialized_name = |symbol| {
let specs_used_in_body =
procs.get_symbol_specializations_used_in(lambda_name, symbol);
match specs_used_in_body {
Some(mut specs) => {
let spec_symbol =
specs.next().map(|(_, sym)| sym).unwrap_or(symbol);
if specs.next().is_some() {
internal_error!(
"polymorphic symbol captures not supported yet"
);
}
spec_symbol
}
None => symbol,
}
};
match closure_layout
@ -4083,9 +4158,7 @@ pub fn with_hole<'a>(
variable,
) {
let real_symbol =
procs
.symbol_specializations
.get_or_insert(env, layout_cache, symbol, variable);
procs.get_or_insert_symbol_specialization(env, layout_cache, symbol, variable);
symbol = real_symbol;
}
@ -4182,7 +4255,7 @@ pub fn with_hole<'a>(
match can_reuse_symbol(env, procs, &loc_arg_expr.value, arg_var) {
// Opaques decay to their argument.
ReuseSymbol::Value(symbol) => {
let real_name = procs.symbol_specializations.get_or_insert(
let real_name = procs.get_or_insert_symbol_specialization(
env,
layout_cache,
symbol,
@ -4247,7 +4320,7 @@ pub fn with_hole<'a>(
can_fields.push(Field::FunctionOrUnspecialized(symbol, variable));
}
Value(symbol) => {
let reusable = procs.symbol_specializations.get_or_insert(
let reusable = procs.get_or_insert_symbol_specialization(
env,
layout_cache,
symbol,
@ -4722,6 +4795,7 @@ pub fn with_hole<'a>(
find_lambda_name(env, layout_cache, lambda_set, name, &[]);
construct_closure_data(
env,
procs,
layout_cache,
lambda_set,
lambda_name,
@ -4775,6 +4849,7 @@ pub fn with_hole<'a>(
find_lambda_name(env, layout_cache, lambda_set, name, &[]);
construct_closure_data(
env,
procs,
layout_cache,
lambda_set,
lambda_name,
@ -5023,6 +5098,7 @@ pub fn with_hole<'a>(
construct_closure_data(
env,
procs,
layout_cache,
lambda_set,
lambda_name,
@ -5169,7 +5245,7 @@ pub fn with_hole<'a>(
}
}
Value(function_symbol) => {
let function_symbol = procs.symbol_specializations.get_or_insert(
let function_symbol = procs.get_or_insert_symbol_specialization(
env,
layout_cache,
function_symbol,
@ -5585,7 +5661,8 @@ where
#[allow(clippy::too_many_arguments)]
fn construct_closure_data<'a, I>(
env: &mut Env<'a, '_>,
layout_cache: &LayoutCache<'a>,
procs: &mut Procs<'a>,
layout_cache: &mut LayoutCache<'a>,
lambda_set: LambdaSet<'a>,
name: LambdaName<'a>,
symbols: I,
@ -5670,12 +5747,19 @@ where
debug_assert_eq!(symbols.len(), 1);
let mut symbols = symbols;
let (captured_symbol, _) = symbols.next().unwrap();
let (captured_symbol, captured_var) = symbols.next().unwrap();
let captured_symbol = procs.get_or_insert_symbol_specialization(
env,
layout_cache,
*captured_symbol,
*captured_var,
);
// The capture set is unwrapped, so just replaced the assigned capture symbol with the
// only capture.
let mut hole = hole.clone();
substitute_in_exprs(env.arena, &mut hole, assigned, *captured_symbol);
substitute_in_exprs(env.arena, &mut hole, assigned, captured_symbol);
hole
}
ClosureRepresentation::EnumDispatch(repr) => match repr {
@ -6056,6 +6140,7 @@ fn tag_union_to_function<'a>(
debug_assert!(lambda_name.no_captures());
construct_closure_data(
env,
procs,
layout_cache,
lambda_set,
lambda_name,
@ -7580,9 +7665,7 @@ fn possible_reuse_symbol_or_specialize<'a>(
) -> Symbol {
match can_reuse_symbol(env, procs, expr, var) {
ReuseSymbol::Value(symbol) => {
procs
.symbol_specializations
.get_or_insert(env, layout_cache, symbol, var)
procs.get_or_insert_symbol_specialization(env, layout_cache, symbol, var)
}
_ => env.unique_symbol(),
}
@ -7658,13 +7741,17 @@ where
// captured symbols can only ever be specialized outside the closure.
// After that is done, remove this hack.
.chain(if no_specializations_needed {
[Some((variable, left))]
[Some((
variable,
left,
procs.specialization_stack.current_deepest_lambda(),
))]
} else {
[None]
})
.flatten();
for (variable, left) in needed_specializations_of_left {
for (variable, left, _deepest_use) in needed_specializations_of_left {
add_needed_external(procs, env, variable, LambdaName::no_niche(right));
let res_layout = layout_cache.from_var(env.arena, variable, env.subs);
@ -7688,7 +7775,7 @@ where
let left_had_specialization_symbols = needed_specializations_of_left.len() > 0;
for (specialization_mark, (specialized_var, specialized_sym)) in
for (specialization_mark, (specialized_var, specialized_sym, deepest_use)) in
needed_specializations_of_left
{
let old_specialized_sym = procs.symbol_specializations.get_or_insert_known(
@ -7696,9 +7783,10 @@ where
specialization_mark,
specialized_var,
specialized_sym,
deepest_use,
);
if let Some((_, old_specialized_sym)) = old_specialized_sym {
if let Some((_, old_specialized_sym, _)) = old_specialized_sym {
scratchpad_update_specializations.push((old_specialized_sym, specialized_sym));
}
}
@ -7881,6 +7969,7 @@ fn specialize_symbol<'a>(
construct_closure_data(
env,
procs,
layout_cache,
lambda_set,
lambda_name,
@ -7932,6 +8021,7 @@ fn specialize_symbol<'a>(
construct_closure_data(
env,
procs,
layout_cache,
lambda_set,
lambda_name,
@ -8337,6 +8427,7 @@ fn call_by_name_help<'a>(
construct_closure_data(
env,
procs,
layout_cache,
lambda_set,
proc_name,
@ -8733,6 +8824,7 @@ fn call_specialized_proc<'a>(
let result = construct_closure_data(
env,
procs,
layout_cache,
lambda_set,
proc_name,

View file

@ -4103,3 +4103,25 @@ fn issue_4348() {
RocStr
);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn issue_4349() {
assert_evals_to!(
indoc!(
r#"
ir = Ok ""
res =
Result.try ir \_ ->
when ir is
Ok "" -> Ok ""
_ -> Err Bad
when res is
Ok _ -> "okay"
_ -> "FAIL"
"#
),
RocStr::from("okay"),
RocStr
);
}