diff --git a/crates/compiler/solve/tests/solve_expr.rs b/crates/compiler/solve/tests/solve_expr.rs index 2abc12d6cc..140b8efd9c 100644 --- a/crates/compiler/solve/tests/solve_expr.rs +++ b/crates/compiler/solve/tests/solve_expr.rs @@ -7841,4 +7841,18 @@ mod solve_expr { "hasher -> hasher | hasher has Hasher", ); } + + #[test] + fn dispatch_tag_union_function_inferred() { + infer_eq_without_problem( + indoc!( + r#" + g = if Bool.true then A else B + + g "" + "# + ), + "[A Str, B Str]*", + ); + } } diff --git a/crates/compiler/unify/src/unify.rs b/crates/compiler/unify/src/unify.rs index fcac3b8a73..bd793af453 100644 --- a/crates/compiler/unify/src/unify.rs +++ b/crates/compiler/unify/src/unify.rs @@ -2685,34 +2685,20 @@ fn unify_flat_type( false, ) } - (FunctionOrTagUnion(tag_names_1, _, ext1), FunctionOrTagUnion(tag_names_2, _, ext2)) => { - let tag_name_1_ref = &env.subs.get_subs_slice(*tag_names_1); - let tag_name_2_ref = &env.subs.get_subs_slice(*tag_names_2); - - if tag_name_1_ref == tag_name_2_ref { - let outcome = unify_pool(env, pool, *ext1, *ext2, ctx.mode); - if outcome.mismatches.is_empty() { - let content = *env.subs.get_content_without_compacting(ctx.second); - merge(env, ctx, content) - } else { - outcome - } - } else { - let empty_tag_var_slices_1 = SubsSlice::extend_new( - &mut env.subs.variable_slices, - std::iter::repeat(Default::default()).take(tag_names_1.len()), - ); - let tags1 = UnionTags::from_slices(*tag_names_1, empty_tag_var_slices_1); - - let empty_tag_var_slices_2 = SubsSlice::extend_new( - &mut env.subs.variable_slices, - std::iter::repeat(Default::default()).take(tag_names_2.len()), - ); - let tags2 = UnionTags::from_slices(*tag_names_2, empty_tag_var_slices_2); - - unify_tag_unions(env, pool, ctx, tags1, *ext1, tags2, *ext2, Rec::None) - } - } + ( + FunctionOrTagUnion(tag_names_1, tag_symbols_1, ext1), + FunctionOrTagUnion(tag_names_2, tag_symbols_2, ext2), + ) => unify_two_function_or_tag_unions( + env, + pool, + ctx, + *tag_names_1, + *tag_symbols_1, + *ext1, + *tag_names_2, + *tag_symbols_2, + *ext2, + ), (TagUnion(tags1, ext1), FunctionOrTagUnion(tag_names, _, ext2)) => { let empty_tag_var_slices = SubsSlice::extend_new( &mut env.subs.variable_slices, @@ -3231,3 +3217,52 @@ fn unify_function_or_tag_union_and_func( outcome } + +fn unify_two_function_or_tag_unions( + env: &mut Env, + pool: &mut Pool, + ctx: &Context, + tag_names_1: SubsSlice, + tag_symbols_1: SubsSlice, + ext1: Variable, + tag_names_2: SubsSlice, + tag_symbols_2: SubsSlice, + ext2: Variable, +) -> Outcome { + let merged_tags = { + let mut all_tags: Vec<_> = (env.subs.get_subs_slice(tag_names_1).iter()) + .chain(env.subs.get_subs_slice(tag_names_2)) + .cloned() + .collect(); + all_tags.sort(); + all_tags.dedup(); + SubsSlice::extend_new(&mut env.subs.tag_names, all_tags) + }; + let merged_lambdas = { + let mut all_lambdas: Vec<_> = (env.subs.get_subs_slice(tag_symbols_1).iter()) + .chain(env.subs.get_subs_slice(tag_symbols_2)) + .cloned() + .collect(); + all_lambdas.sort(); + all_lambdas.dedup(); + SubsSlice::extend_new(&mut env.subs.closure_names, all_lambdas) + }; + + let mut outcome = unify_pool(env, pool, ext1, ext2, ctx.mode); + if !outcome.mismatches.is_empty() { + return outcome; + } + + let merge_outcome = merge( + env, + ctx, + Content::Structure(FlatType::FunctionOrTagUnion( + merged_tags, + merged_lambdas, + ext1, + )), + ); + + outcome.union(merge_outcome); + outcome +}