diff --git a/crates/compiler/uitest/tests/recursive_type/generalize_introduced_recursion_variable_issue_4770.txt b/crates/compiler/uitest/tests/recursive_type/generalize_introduced_recursion_variable_issue_4770.txt index 12417c2d10..4402434d49 100644 --- a/crates/compiler/uitest/tests/recursive_type/generalize_introduced_recursion_variable_issue_4770.txt +++ b/crates/compiler/uitest/tests/recursive_type/generalize_introduced_recursion_variable_issue_4770.txt @@ -1,10 +1,10 @@ app "test" provides [main] to "./platform" main = isCorrectOrder (IsList [IsStr ""]) -# ^^^^^^^^^^^^^^ [IsList (List [IsList (List a), IsStr Str]), IsStr Str] -[[isCorrectOrder(1)]]-> Bool +# ^^^^^^^^^^^^^^ [IsList (List a), IsStr Str] as a -[[isCorrectOrder(1)]]-> Bool isCorrectOrder = \pair -> -#^^^^^^^^^^^^^^{-1} [IsList (List [IsList (List a), IsStr *]), IsStr *] -[[isCorrectOrder(1)]]-> Bool +#^^^^^^^^^^^^^^{-1} [IsList (List a), IsStr *] as a -[[isCorrectOrder(1)]]-> Bool when pair is IsList l -> List.all l isCorrectOrder IsStr _ -> isCorrectOrder (IsList [pair]) diff --git a/crates/compiler/unify/src/unify.rs b/crates/compiler/unify/src/unify.rs index fbb9b6a1a0..cd0c834cba 100644 --- a/crates/compiler/unify/src/unify.rs +++ b/crates/compiler/unify/src/unify.rs @@ -2721,7 +2721,6 @@ fn unify_tag_unions( initial_ext1: TagExt, tags2: UnionTags, initial_ext2: TagExt, - recursion_var: Rec, ) -> Outcome { let (separate, mut ext1, mut ext2) = separate_union_tags(env.subs, tags1, initial_ext1, tags2, initial_ext2); @@ -2781,7 +2780,6 @@ fn unify_tag_unions( shared_tags, OtherTags2::Empty, merge_tag_exts(ext1, ext2), - recursion_var, ); shared_tags_outcome.union(ext_outcome); @@ -2818,15 +2816,8 @@ fn unify_tag_unions( let combined_ext = ext1.map(|_| extra_tags_in_2); - let mut shared_tags_outcome = unify_shared_tags( - env, - pool, - ctx, - shared_tags, - OtherTags2::Empty, - combined_ext, - recursion_var, - ); + let mut shared_tags_outcome = + unify_shared_tags(env, pool, ctx, shared_tags, OtherTags2::Empty, combined_ext); shared_tags_outcome.union(ext_outcome); @@ -2881,15 +2872,8 @@ fn unify_tag_unions( let combined_ext = ext2.map(|_| extra_tags_in_1); - let shared_tags_outcome = unify_shared_tags( - env, - pool, - ctx, - shared_tags, - OtherTags2::Empty, - combined_ext, - recursion_var, - ); + let shared_tags_outcome = + unify_shared_tags(env, pool, ctx, shared_tags, OtherTags2::Empty, combined_ext); total_outcome.union(shared_tags_outcome); if extend_ext_with_uninhabited { @@ -2954,8 +2938,7 @@ fn unify_tag_unions( env.subs.commit_snapshot(snapshot); - let shared_tags_outcome = - unify_shared_tags(env, pool, ctx, shared_tags, other_tags, ext, recursion_var); + let shared_tags_outcome = unify_shared_tags(env, pool, ctx, shared_tags, other_tags, ext); total_outcome.union(shared_tags_outcome); total_outcome } @@ -3075,6 +3058,24 @@ fn choose_merged_var(subs: &Subs, var1: Variable, var2: Variable) -> Variable { } } +#[inline] +fn find_union_rec(subs: &Subs, ctx: &Context) -> Rec { + match ( + subs.get_content_without_compacting(ctx.first), + subs.get_content_without_compacting(ctx.second), + ) { + (Structure(s1), Structure(s2)) => match (s1, s2) { + (FlatType::RecursiveTagUnion(l, _, _), FlatType::RecursiveTagUnion(r, _, _)) => { + Rec::Both(*l, *r) + } + (FlatType::RecursiveTagUnion(l, _, _), _) => Rec::Left(*l), + (_, FlatType::RecursiveTagUnion(r, _, _)) => Rec::Right(*r), + _ => Rec::None, + }, + _ => Rec::None, + } +} + #[must_use] fn unify_shared_tags( env: &mut Env, @@ -3083,7 +3084,6 @@ fn unify_shared_tags( shared_tags: Vec<(TagName, (VariableSubsSlice, VariableSubsSlice))>, other_tags: OtherTags2, ext: TagExt, - recursion_var: Rec, ) -> Outcome { let mut matching_tags = Vec::default(); let num_shared_tags = shared_tags.len(); @@ -3192,6 +3192,13 @@ fn unify_shared_tags( } }; + // Look up if either unions are recursive, and if so, what the recursive variable is. + // + // We wait until we're about to merge the unions to do this, since above, while unifying + // payloads, we may have promoted a non-recursive union involved in this unification to + // a recursive one. + let recursion_var = find_union_rec(env.subs, ctx); + let merge_outcome = unify_shared_tags_merge(env, ctx, new_tags, new_ext_var, recursion_var); total_outcome.union(merge_outcome); @@ -3275,24 +3282,20 @@ fn unify_flat_type( } (TagUnion(tags1, ext1), TagUnion(tags2, ext2)) => { - unify_tag_unions(env, pool, ctx, *tags1, *ext1, *tags2, *ext2, Rec::None) + unify_tag_unions(env, pool, ctx, *tags1, *ext1, *tags2, *ext2) } (RecursiveTagUnion(recursion_var, tags1, ext1), TagUnion(tags2, ext2)) => { debug_assert!(is_recursion_var(env.subs, *recursion_var)); // this never happens in type-correct programs, but may happen if there is a type error - let rec = Rec::Left(*recursion_var); - - unify_tag_unions(env, pool, ctx, *tags1, *ext1, *tags2, *ext2, rec) + unify_tag_unions(env, pool, ctx, *tags1, *ext1, *tags2, *ext2) } (TagUnion(tags1, ext1), RecursiveTagUnion(recursion_var, tags2, ext2)) => { debug_assert!(is_recursion_var(env.subs, *recursion_var)); - let rec = Rec::Right(*recursion_var); - - unify_tag_unions(env, pool, ctx, *tags1, *ext1, *tags2, *ext2, rec) + unify_tag_unions(env, pool, ctx, *tags1, *ext1, *tags2, *ext2) } (RecursiveTagUnion(rec1, tags1, ext1), RecursiveTagUnion(rec2, tags2, ext2)) => { @@ -3307,8 +3310,7 @@ fn unify_flat_type( env.subs.dbg(*rec2) ); - let rec = Rec::Both(*rec1, *rec2); - let mut outcome = unify_tag_unions(env, pool, ctx, *tags1, *ext1, *tags2, *ext2, rec); + let mut outcome = unify_tag_unions(env, pool, ctx, *tags1, *ext1, *tags2, *ext2); outcome.union(unify_pool(env, pool, *rec1, *rec2, ctx.mode)); outcome @@ -3407,7 +3409,7 @@ fn unify_flat_type( ); let tags2 = UnionTags::from_slices(*tag_names, empty_tag_var_slices); - unify_tag_unions(env, pool, ctx, *tags1, *ext1, tags2, *ext2, Rec::None) + unify_tag_unions(env, pool, ctx, *tags1, *ext1, tags2, *ext2) } (FunctionOrTagUnion(tag_names, _, ext1), TagUnion(tags2, ext2)) => { let empty_tag_var_slices = SubsSlice::extend_new( @@ -3416,7 +3418,7 @@ fn unify_flat_type( ); let tags1 = UnionTags::from_slices(*tag_names, empty_tag_var_slices); - unify_tag_unions(env, pool, ctx, tags1, *ext1, *tags2, *ext2, Rec::None) + unify_tag_unions(env, pool, ctx, tags1, *ext1, *tags2, *ext2) } (RecursiveTagUnion(recursion_var, tags1, ext1), FunctionOrTagUnion(tag_names, _, ext2)) => { @@ -3428,9 +3430,8 @@ fn unify_flat_type( std::iter::repeat(Default::default()).take(tag_names.len()), ); let tags2 = UnionTags::from_slices(*tag_names, empty_tag_var_slices); - let rec = Rec::Left(*recursion_var); - unify_tag_unions(env, pool, ctx, *tags1, *ext1, tags2, *ext2, rec) + unify_tag_unions(env, pool, ctx, *tags1, *ext1, tags2, *ext2) } (FunctionOrTagUnion(tag_names, _, ext1), RecursiveTagUnion(recursion_var, tags2, ext2)) => { @@ -3441,9 +3442,8 @@ fn unify_flat_type( std::iter::repeat(Default::default()).take(tag_names.len()), ); let tags1 = UnionTags::from_slices(*tag_names, empty_tag_var_slices); - let rec = Rec::Right(*recursion_var); - unify_tag_unions(env, pool, ctx, tags1, *ext1, *tags2, *ext2, rec) + unify_tag_unions(env, pool, ctx, tags1, *ext1, *tags2, *ext2) } // these have underscores because they're unused in --release builds