From 88618c098d5e353bf0ac29fcf749723178521a62 Mon Sep 17 00:00:00 2001 From: Ayaz Hafiz Date: Tue, 28 Jun 2022 16:45:46 -0400 Subject: [PATCH] Unify lambda sets with left/right closure capture differences --- crates/compiler/solve/tests/solve_expr.rs | 8 +- crates/compiler/unify/src/unify.rs | 218 +++++++++++++++------- 2 files changed, 159 insertions(+), 67 deletions(-) diff --git a/crates/compiler/solve/tests/solve_expr.rs b/crates/compiler/solve/tests/solve_expr.rs index 9b83cc8233..a0200c0dd9 100644 --- a/crates/compiler/solve/tests/solve_expr.rs +++ b/crates/compiler/solve/tests/solve_expr.rs @@ -6974,12 +6974,18 @@ mod solve_expr { fun = when x is True -> capture "" + # ^^^^^^^ False -> capture {} + # ^^^^^^^ fun #^^^{-1} "# ), - &["fun : {} -[[thunk(5) {}, thunk(5) Str]]-> Str"], + &[ + "capture : Str -[[capture(1)]]-> ({} -[[thunk(5) {}, thunk(5) Str]]-> Str)", + "capture : {} -[[capture(1)]]-> ({} -[[thunk(5) {}, thunk(5) Str]]-> Str)", + "fun : {} -[[thunk(5) {}, thunk(5) Str]]-> Str", + ] ); } diff --git a/crates/compiler/unify/src/unify.rs b/crates/compiler/unify/src/unify.rs index fb525f2c8e..58c0651208 100644 --- a/crates/compiler/unify/src/unify.rs +++ b/crates/compiler/unify/src/unify.rs @@ -966,6 +966,129 @@ fn extract_specialization_lambda_set( outcome } +#[derive(Debug)] +struct Sides { + left: Vec<(Symbol, VariableSubsSlice)>, + right: Vec<(Symbol, VariableSubsSlice)>, +} + +impl Default for Sides { + fn default() -> Self { + Self { + left: Vec::with_capacity(1), + right: Vec::with_capacity(1), + } + } +} + +fn separate_union_lambdas( + subs: &mut Subs, + pool: &mut Pool, + fields1: UnionLambdas, + fields2: UnionLambdas, +) -> ( + Vec<(Symbol, VariableSubsSlice)>, + Vec<(Symbol, VariableSubsSlice)>, + Vec<(Symbol, VariableSubsSlice)>, +) { + debug_assert!(fields1.is_sorted(subs)); + debug_assert!(fields2.is_sorted(subs)); + + // lambda names -> (the captures for that lambda on the left side, the captures for that lambda on the right side) + // 
e.g. [[F1 U8], [F1 U64], [F2 a]] ~ [[F1 Str], [F2 Str]] becomes + // F1 -> { left: [ [U8], [U64] ], right: [ [Str] ] } + // F2 -> { left: [ [a] ], right: [ [Str] ] } + let mut buckets: VecMap = VecMap::with_capacity(fields1.len() + fields2.len()); + + for (sym, vars) in fields1.iter_all() { + let bucket = buckets.get_or_insert(subs[sym], Sides::default); + bucket.left.push((subs[sym], subs[vars])); + } + for (sym, vars) in fields2.iter_all() { + let bucket = buckets.get_or_insert(subs[sym], Sides::default); + bucket.right.push((subs[sym], subs[vars])); + } + + let mut only_in_left = Vec::with_capacity(fields1.len()); + let mut only_in_right = Vec::with_capacity(fields2.len()); + let mut joined = Vec::with_capacity(fields1.len() + fields2.len()); + for (lambda_name, Sides { left, mut right }) in buckets { + match (left.as_slice(), right.as_slice()) { + (&[], &[]) => internal_error!("somehow both are empty but there's an entry?"), + (&[], _) => only_in_right.extend(right), + (_, &[]) => only_in_left.extend(left), + (_, _) => { + 'next_left: for (_, left_slice) in left { + // Does the current slice on the left unify with a slice on the right? + // + // If yes, we unify them and add the unified result to `joined`. + // + // Otherwise if no such slice on the right is found, then the slice on the `left` has no slice, + // either on the left or right, it unifies with (since the left was constructed + // inductively via the same procedure). + // + // At the end each slice in the left and right has been explored, so + // - `joined` contains all the slices that can unify + // - left contains unique captures slices that will unify with no other slice + // - right contains unique captures slices that will unify with no other slice + // + // Note also if a slice l on the left and a slice r on the right unify, there + // is no other r' != r on the right such that l ~ r', and respectively there is + // no other l' != l on the left such that l' ~ r. 
Otherwise, it must be that l ~ l' + // (resp. r ~ r'), but then l = l' (resp. r = r'), and they would have become the same + // slice in a previous call to `separate_union_lambdas`. + 'try_next_right: for (right_index, (_, right_slice)) in right.iter().enumerate() + { + if left_slice.len() != right_slice.len() { + continue 'try_next_right; + } + + let snapshot = subs.snapshot(); + for (var1, var2) in (left_slice.into_iter()).zip(right_slice.into_iter()) { + let (var1, var2) = (subs[var1], subs[var2]); + + // Lambda sets are effectively tags under another name, and their usage can also result + // in the arguments of a lambda name being recursive. It very well may happen that + // during unification, a lambda set previously marked as not recursive becomes + // recursive. See the docs of [LambdaSet] for one example, or https://github.com/rtfeldman/roc/pull/2307. + // + // Like with tag unions, if it has, we'll always pass through this branch. So, take + // this opportunity to promote the lambda set to recursive if need be. + maybe_mark_union_recursive(subs, var1); + maybe_mark_union_recursive(subs, var2); + + let outcome = + unify_pool::(subs, pool, var1, var2, Mode::EQ); + + if !outcome.mismatches.is_empty() { + subs.rollback_to(snapshot); + continue 'try_next_right; + } + } + + // All the variables unified, so we can join the left + right. + // The variables are unified in left and right slice, so just reuse the left slice. + joined.push((lambda_name, left_slice)); + // Remove the right slice, it unifies with the left so this is its unique + // unification. + right.swap_remove(right_index); + continue 'next_left; + } + + // No slice on the right unified with the left, so the slice on the left is on + // its own. + only_in_left.push((lambda_name, left_slice)); + } + + // Possible that there are items left over in the right, they are on their own. 
+ only_in_right.extend(right); + } + } + } + + (only_in_left, only_in_right, joined) +} + fn unify_lambda_set_help( subs: &mut Subs, pool: &mut Pool, @@ -994,65 +1117,20 @@ fn unify_lambda_set_help( "Recursion var is present, but it doesn't have a recursive content!" ); - let Separate { - only_in_1, - only_in_2, - in_both, - } = separate_union_lambdas(subs, solved1, solved2); + let (only_in_1, only_in_2, in_both) = separate_union_lambdas(subs, pool, solved1, solved2); - let mut new_lambdas = vec![]; - for (lambda_name, (vars1, vars2)) in in_both { - let mut captures_unify = vars1.len() == vars2.len(); + let all_lambdas = in_both + .into_iter() + .map(|(name, slice)| (name, subs.get_subs_slice(slice).to_vec())); - if captures_unify { - for (var1, var2) in (vars1.into_iter()).zip(vars2.into_iter()) { - let (var1, var2) = (subs[var1], subs[var2]); - - // Lambda sets are effectively tags under another name, and their usage can also result - // in the arguments of a lambda name being recursive. It very well may happen that - // during unification, a lambda set previously marked as not recursive becomes - // recursive. See the docs of [LambdaSet] for one example, or https://github.com/rtfeldman/roc/pull/2307. - // - // Like with tag unions, if it has, we'll always pass through this branch. So, take - // this opportunity to promote the lambda set to recursive if need be. - maybe_mark_union_recursive(subs, var1); - maybe_mark_union_recursive(subs, var2); - - let snapshot = subs.snapshot(); - let outcome = unify_pool::(subs, pool, var1, var2, ctx.mode); - - if outcome.mismatches.is_empty() { - subs.commit_snapshot(snapshot); - } else { - captures_unify = false; - subs.rollback_to(snapshot); - // Continue so the other variables can unify if possible, allowing us to re-use - // shared variables. 
- } - } - } - - if captures_unify { - new_lambdas.push((lambda_name, subs.get_subs_slice(vars1).to_vec())); - } else { - debug_assert!((subs.get_subs_slice(vars1).iter()) - .zip(subs.get_subs_slice(vars2).iter()) - .any(|(v1, v2)| !subs.equivalent_without_compacting(*v1, *v2))); - - new_lambdas.push((lambda_name, subs.get_subs_slice(vars1).to_vec())); - new_lambdas.push((lambda_name, subs.get_subs_slice(vars2).to_vec())); - } - } - - let all_lambdas = new_lambdas; - let all_lambdas = merge_sorted( + let all_lambdas = merge_sorted_preserving_duplicates( all_lambdas, only_in_1.into_iter().map(|(name, subs_slice)| { let vec = subs.get_subs_slice(subs_slice).to_vec(); (name, vec) }), ); - let all_lambdas = merge_sorted( + let all_lambdas = merge_sorted_preserving_duplicates( all_lambdas, only_in_2.into_iter().map(|(name, subs_slice)| { let vec = subs.get_subs_slice(subs_slice).to_vec(); @@ -1396,7 +1474,7 @@ struct Separate { in_both: Vec<(K, (V, V))>, } -fn merge_sorted(input1: I1, input2: I2) -> Vec<(K, V)> +fn merge_sorted_help(input1: I1, input2: I2, preserve_duplicates: bool) -> Vec<(K, V)> where K: Ord, I1: IntoIterator, @@ -1426,8 +1504,11 @@ where } Some(Ordering::Equal) => { let (k, v) = it1.next().unwrap(); - let (_, _) = it2.next().unwrap(); + let (k2, v2) = it2.next().unwrap(); result.push((k, v)); + if preserve_duplicates { + result.push((k2, v2)); + } } Some(Ordering::Greater) => { result.push(it2.next().unwrap()); @@ -1439,6 +1520,24 @@ where result } +fn merge_sorted(input1: I1, input2: I2) -> Vec<(K, V)> +where + K: Ord, + I1: IntoIterator, + I2: IntoIterator, +{ + merge_sorted_help(input1, input2, false) +} + +fn merge_sorted_preserving_duplicates(input1: I1, input2: I2) -> Vec<(K, V)> +where + K: Ord, + I1: IntoIterator, + I2: IntoIterator, +{ + merge_sorted_help(input1, input2, true) +} + fn separate(input1: I1, input2: I2) -> Separate where K: Ord, @@ -1501,19 +1600,6 @@ fn separate_union_tags( (separate(it1, it2), new_ext1, new_ext2) } -fn 
separate_union_lambdas( - subs: &Subs, - fields1: UnionLambdas, - fields2: UnionLambdas, -) -> Separate { - debug_assert!(fields1.is_sorted(subs)); - debug_assert!(fields2.is_sorted(subs)); - let it1 = fields1.iter_all().map(|(s, vars)| (subs[s], subs[vars])); - let it2 = fields2.iter_all().map(|(s, vars)| (subs[s], subs[vars])); - - separate(it1, it2) -} - #[derive(Debug, Copy, Clone)] enum Rec { None,