Merge pull request #5628 from roc-lang/i5617

Do not drop uninhabited captures from lambda sets
This commit is contained in:
Ayaz 2023-07-01 13:15:14 -05:00 committed by GitHub
commit f2e013a4e0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 262 additions and 14 deletions

View file

@ -1914,6 +1914,11 @@ impl<'a> LambdaSet<'a> {
) -> Cacheable<InLayout<'a>> {
let union_labels = UnsortedUnionLabels { tags: set };
// Even if a variant in the lambda set has uninhabitable captures (and is hence
// unreachable as a function), we want to keep it in the representation. Failing to do so
// risks dropping relevant specializations needed during monomorphization.
let drop_uninhabited_variants = DropUninhabitedVariants(false);
match opt_rec_var {
Some(rec_var) => {
let Cacheable(result, criteria) =
@ -1922,7 +1927,7 @@ impl<'a> LambdaSet<'a> {
Cacheable(result, criteria)
}
None => layout_from_non_recursive_union(env, &union_labels),
None => layout_from_non_recursive_union(env, &union_labels, drop_uninhabited_variants),
}
}
@ -3330,7 +3335,7 @@ fn layout_from_flat_type<'a>(
debug_assert!(ext_var_is_empty_tag_union(subs, ext_var));
layout_from_non_recursive_union(env, &tags).map(Ok)
layout_from_non_recursive_union(env, &tags, DropUninhabitedVariants(true)).map(Ok)
}
FunctionOrTagUnion(tag_names, _, ext_var) => {
debug_assert!(
@ -3343,7 +3348,8 @@ fn layout_from_flat_type<'a>(
tags: tag_names.iter().map(|t| (t, &[] as &[Variable])).collect(),
};
layout_from_non_recursive_union(env, &unsorted_tags).map(Ok)
layout_from_non_recursive_union(env, &unsorted_tags, DropUninhabitedVariants(true))
.map(Ok)
}
RecursiveTagUnion(rec_var, tags, ext_var) => {
let (tags, ext_var) = tags.unsorted_tags_and_ext(subs, ext_var);
@ -3621,11 +3627,14 @@ pub fn union_sorted_tags<'a>(
var
};
let drop_uninhabited_variants = DropUninhabitedVariants(true);
let mut tags_vec = std::vec::Vec::new();
let result = match roc_types::pretty_print::chase_ext_tag_union(env.subs, var, &mut tags_vec) {
ChasedExt::Empty => {
let opt_rec_var = get_recursion_var(env.subs, var);
let Cacheable(result, _) = union_sorted_tags_help(env, tags_vec, opt_rec_var);
let Cacheable(result, _) =
union_sorted_tags_help(env, tags_vec, opt_rec_var, drop_uninhabited_variants);
result
}
ChasedExt::NonEmpty { content, .. } => {
@ -3638,12 +3647,22 @@ pub fn union_sorted_tags<'a>(
// x
// In such cases it's fine to drop the variable. We may be proven wrong in the future...
let opt_rec_var = get_recursion_var(env.subs, var);
let Cacheable(result, _) = union_sorted_tags_help(env, tags_vec, opt_rec_var);
let Cacheable(result, _) = union_sorted_tags_help(
env,
tags_vec,
opt_rec_var,
drop_uninhabited_variants,
);
result
}
RecursionVar { .. } => {
let opt_rec_var = get_recursion_var(env.subs, var);
let Cacheable(result, _) = union_sorted_tags_help(env, tags_vec, opt_rec_var);
let Cacheable(result, _) = union_sorted_tags_help(
env,
tags_vec,
opt_rec_var,
drop_uninhabited_variants,
);
result
}
@ -3694,9 +3713,12 @@ impl Label for Symbol {
}
}
struct DropUninhabitedVariants(bool);
fn union_sorted_non_recursive_tags_help<'a, L>(
env: &mut Env<'a, '_>,
tags_list: &mut Vec<'_, &'_ (&'_ L, &[Variable])>,
drop_uninhabited_variants: DropUninhabitedVariants,
) -> Cacheable<UnionVariant<'a>>
where
L: Label + Ord + Clone + Into<TagOrClosure>,
@ -3816,7 +3838,7 @@ where
answer.push((tag_name.clone().into(), arg_layouts.into_bump_slice()));
}
if inhabited_tag_ids.count_ones() == 1 {
if inhabited_tag_ids.count_ones() == 1 && drop_uninhabited_variants.0 {
let kept_tag_id = inhabited_tag_ids.first_one().unwrap();
let kept = answer.get(kept_tag_id).unwrap();
@ -3869,13 +3891,14 @@ pub fn union_sorted_tags_pub<'a, L>(
where
L: Into<TagOrClosure> + Ord + Clone,
{
union_sorted_tags_help(env, tags_vec, opt_rec_var).value()
union_sorted_tags_help(env, tags_vec, opt_rec_var, DropUninhabitedVariants(true)).value()
}
fn union_sorted_tags_help<'a, L>(
env: &mut Env<'a, '_>,
mut tags_vec: std::vec::Vec<(L, std::vec::Vec<Variable>)>,
opt_rec_var: Option<Variable>,
drop_uninhabited_variants: DropUninhabitedVariants,
) -> Cacheable<UnionVariant<'a>>
where
L: Into<TagOrClosure> + Ord + Clone,
@ -4028,7 +4051,7 @@ where
answer.push((tag_name.into(), arg_layouts.into_bump_slice()));
}
if inhabited_tag_ids.count_ones() == 1 && !is_recursive {
if inhabited_tag_ids.count_ones() == 1 && !is_recursive && drop_uninhabited_variants.0 {
let kept_tag_id = inhabited_tag_ids.first_one().unwrap();
let kept = answer.get(kept_tag_id).unwrap();
@ -4131,6 +4154,7 @@ fn layout_from_newtype<'a, L: Label>(
fn layout_from_non_recursive_union<'a, L>(
env: &mut Env<'a, '_>,
tags: &UnsortedUnionLabels<L>,
drop_uninhabited_variants: DropUninhabitedVariants,
) -> Cacheable<InLayout<'a>>
where
L: Label + Ord + Into<TagOrClosure>,
@ -4146,7 +4170,8 @@ where
let mut criteria = CACHEABLE;
let variant =
union_sorted_non_recursive_tags_help(env, &mut tags_vec).decompose(&mut criteria, env.subs);
union_sorted_non_recursive_tags_help(env, &mut tags_vec, drop_uninhabited_variants)
.decompose(&mut criteria, env.subs);
let compute_semantic = || L::semantic_repr(env.arena, tags_vec.iter().map(|(l, _)| *l));