Unify lambda sets with left/right closure capture differences

This commit is contained in:
Ayaz Hafiz 2022-06-28 16:45:46 -04:00 committed by ayazhafiz
parent 5f8b509cb3
commit 88618c098d
No known key found for this signature in database
GPG key ID: B443F7A3030C9AED
2 changed files with 159 additions and 67 deletions

View file

@ -966,6 +966,129 @@ fn extract_specialization_lambda_set<M: MetaCollector>(
outcome
}
#[derive(Debug)]
struct Sides {
left: Vec<(Symbol, VariableSubsSlice)>,
right: Vec<(Symbol, VariableSubsSlice)>,
}
impl Default for Sides {
fn default() -> Self {
Self {
left: Vec::with_capacity(1),
right: Vec::with_capacity(1),
}
}
}
fn separate_union_lambdas(
subs: &mut Subs,
pool: &mut Pool,
fields1: UnionLambdas,
fields2: UnionLambdas,
) -> (
Vec<(Symbol, VariableSubsSlice)>,
Vec<(Symbol, VariableSubsSlice)>,
Vec<(Symbol, VariableSubsSlice)>,
) {
debug_assert!(fields1.is_sorted(subs));
debug_assert!(fields2.is_sorted(subs));
// lambda names -> (the captures for that lambda on the left side, the captures for that lambda on the right side)
// e.g. [[F1 U8], [F1 U64], [F2 a]] ~ [[F1 Str], [F2 Str]] becomes
// F1 -> { left: [ [U8], [U64] ], right: [ [Str] ] }
// F2 -> { left: [ [a] ], right: [ [Str] ] }
let mut buckets: VecMap<Symbol, Sides> = VecMap::with_capacity(fields1.len() + fields2.len());
for (sym, vars) in fields1.iter_all() {
let bucket = buckets.get_or_insert(subs[sym], Sides::default);
bucket.left.push((subs[sym], subs[vars]));
}
for (sym, vars) in fields2.iter_all() {
let bucket = buckets.get_or_insert(subs[sym], Sides::default);
bucket.right.push((subs[sym], subs[vars]));
}
let mut only_in_left = Vec::with_capacity(fields1.len());
let mut only_in_right = Vec::with_capacity(fields2.len());
let mut joined = Vec::with_capacity(fields1.len() + fields2.len());
for (lambda_name, Sides { left, mut right }) in buckets {
match (left.as_slice(), right.as_slice()) {
(&[], &[]) => internal_error!("somehow both are empty but there's an entry?"),
(&[], _) => only_in_right.extend(right),
(_, &[]) => only_in_left.extend(left),
(_, _) => {
'next_left: for (_, left_slice) in left {
// Does the current slice on the left unify with a slice on the right?
//
// If yes, we unify then and the unified result to `joined`.
//
// Otherwise if no such slice on the right is found, then the slice on the `left` has no slice,
// either on the left or right, it unifies with (since the left was constructed
// inductively via the same procedure).
//
// At the end each slice in the left and right has been explored, so
// - `joined` contains all the slices that can unify
// - left contains unique captures slices that will unify with no other slice
// - right contains unique captures slices that will unify with no other slice
//
// Note also if a slice l on the left and a slice r on the right unify, there
// is no other r' != r on the right such that l ~ r', and respectively there is
// no other l' != l on the left such that l' ~ r. Otherwise, it must be that l ~ l'
// (resp. r ~ r'), but then l = l' (resp. r = r'), and they would have become the same
// slice in a previous call to `separate_union_lambdas`.
'try_next_right: for (right_index, (_, right_slice)) in right.iter().enumerate()
{
if left_slice.len() != right_slice.len() {
continue 'try_next_right;
}
let snapshot = subs.snapshot();
for (var1, var2) in (left_slice.into_iter()).zip(right_slice.into_iter()) {
let (var1, var2) = (subs[var1], subs[var2]);
// Lambda sets are effectively tags under another name, and their usage can also result
// in the arguments of a lambda name being recursive. It very well may happen that
// during unification, a lambda set previously marked as not recursive becomes
// recursive. See the docs of [LambdaSet] for one example, or https://github.com/rtfeldman/roc/pull/2307.
//
// Like with tag unions, if it has, we'll always pass through this branch. So, take
// this opportunity to promote the lambda set to recursive if need be.
maybe_mark_union_recursive(subs, var1);
maybe_mark_union_recursive(subs, var2);
let outcome =
unify_pool::<NoCollector>(subs, pool, var1, var2, Mode::EQ);
if !outcome.mismatches.is_empty() {
subs.rollback_to(snapshot);
continue 'try_next_right;
}
}
// All the variables unified, so we can join the left + right.
// The variables are unified in left and right slice, so just reuse the left slice.
joined.push((lambda_name, left_slice));
// Remove the right slice, it unifies with the left so this is its unique
// unification.
right.swap_remove(right_index);
continue 'next_left;
}
// No slice on the right unified with the left, so the slice on the left is on
// its own.
only_in_left.push((lambda_name, left_slice));
}
// Possible that there are items left over in the right, they are on their own.
only_in_right.extend(right);
}
}
}
(only_in_left, only_in_right, joined)
}
fn unify_lambda_set_help<M: MetaCollector>(
subs: &mut Subs,
pool: &mut Pool,
@ -994,65 +1117,20 @@ fn unify_lambda_set_help<M: MetaCollector>(
"Recursion var is present, but it doesn't have a recursive content!"
);
let Separate {
only_in_1,
only_in_2,
in_both,
} = separate_union_lambdas(subs, solved1, solved2);
let (only_in_1, only_in_2, in_both) = separate_union_lambdas(subs, pool, solved1, solved2);
let mut new_lambdas = vec![];
for (lambda_name, (vars1, vars2)) in in_both {
let mut captures_unify = vars1.len() == vars2.len();
let all_lambdas = in_both
.into_iter()
.map(|(name, slice)| (name, subs.get_subs_slice(slice).to_vec()));
if captures_unify {
for (var1, var2) in (vars1.into_iter()).zip(vars2.into_iter()) {
let (var1, var2) = (subs[var1], subs[var2]);
// Lambda sets are effectively tags under another name, and their usage can also result
// in the arguments of a lambda name being recursive. It very well may happen that
// during unification, a lambda set previously marked as not recursive becomes
// recursive. See the docs of [LambdaSet] for one example, or https://github.com/rtfeldman/roc/pull/2307.
//
// Like with tag unions, if it has, we'll always pass through this branch. So, take
// this opportunity to promote the lambda set to recursive if need be.
maybe_mark_union_recursive(subs, var1);
maybe_mark_union_recursive(subs, var2);
let snapshot = subs.snapshot();
let outcome = unify_pool::<M>(subs, pool, var1, var2, ctx.mode);
if outcome.mismatches.is_empty() {
subs.commit_snapshot(snapshot);
} else {
captures_unify = false;
subs.rollback_to(snapshot);
// Continue so the other variables can unify if possible, allowing us to re-use
// shared variables.
}
}
}
if captures_unify {
new_lambdas.push((lambda_name, subs.get_subs_slice(vars1).to_vec()));
} else {
debug_assert!((subs.get_subs_slice(vars1).iter())
.zip(subs.get_subs_slice(vars2).iter())
.any(|(v1, v2)| !subs.equivalent_without_compacting(*v1, *v2)));
new_lambdas.push((lambda_name, subs.get_subs_slice(vars1).to_vec()));
new_lambdas.push((lambda_name, subs.get_subs_slice(vars2).to_vec()));
}
}
let all_lambdas = new_lambdas;
let all_lambdas = merge_sorted(
let all_lambdas = merge_sorted_preserving_duplicates(
all_lambdas,
only_in_1.into_iter().map(|(name, subs_slice)| {
let vec = subs.get_subs_slice(subs_slice).to_vec();
(name, vec)
}),
);
let all_lambdas = merge_sorted(
let all_lambdas = merge_sorted_preserving_duplicates(
all_lambdas,
only_in_2.into_iter().map(|(name, subs_slice)| {
let vec = subs.get_subs_slice(subs_slice).to_vec();
@ -1396,7 +1474,7 @@ struct Separate<K, V> {
in_both: Vec<(K, (V, V))>,
}
fn merge_sorted<K, V, I1, I2>(input1: I1, input2: I2) -> Vec<(K, V)>
fn merge_sorted_help<K, V, I1, I2>(input1: I1, input2: I2, preserve_duplicates: bool) -> Vec<(K, V)>
where
K: Ord,
I1: IntoIterator<Item = (K, V)>,
@ -1426,8 +1504,11 @@ where
}
Some(Ordering::Equal) => {
let (k, v) = it1.next().unwrap();
let (_, _) = it2.next().unwrap();
let (k2, v2) = it2.next().unwrap();
result.push((k, v));
if preserve_duplicates {
result.push((k2, v2));
}
}
Some(Ordering::Greater) => {
result.push(it2.next().unwrap());
@ -1439,6 +1520,24 @@ where
result
}
fn merge_sorted<K, V, I1, I2>(input1: I1, input2: I2) -> Vec<(K, V)>
where
K: Ord,
I1: IntoIterator<Item = (K, V)>,
I2: IntoIterator<Item = (K, V)>,
{
merge_sorted_help(input1, input2, false)
}
fn merge_sorted_preserving_duplicates<K, V, I1, I2>(input1: I1, input2: I2) -> Vec<(K, V)>
where
K: Ord,
I1: IntoIterator<Item = (K, V)>,
I2: IntoIterator<Item = (K, V)>,
{
merge_sorted_help(input1, input2, true)
}
fn separate<K, V, I1, I2>(input1: I1, input2: I2) -> Separate<K, V>
where
K: Ord,
@ -1501,19 +1600,6 @@ fn separate_union_tags(
(separate(it1, it2), new_ext1, new_ext2)
}
fn separate_union_lambdas(
subs: &Subs,
fields1: UnionLambdas,
fields2: UnionLambdas,
) -> Separate<Symbol, VariableSubsSlice> {
debug_assert!(fields1.is_sorted(subs));
debug_assert!(fields2.is_sorted(subs));
let it1 = fields1.iter_all().map(|(s, vars)| (subs[s], subs[vars]));
let it2 = fields2.iter_all().map(|(s, vars)| (subs[s], subs[vars]));
separate(it1, it2)
}
#[derive(Debug, Copy, Clone)]
enum Rec {
None,