diff --git a/ast/src/solve_type.rs b/ast/src/solve_type.rs index 6569d2586b..ba8d262a8c 100644 --- a/ast/src/solve_type.rs +++ b/ast/src/solve_type.rs @@ -1408,7 +1408,10 @@ fn adjust_rank_content( } } - LambdaSet(subs::LambdaSet { solved }) => { + LambdaSet(subs::LambdaSet { + solved, + recursion_var, + }) => { let mut rank = group_rank; for (_, index) in solved.iter_all() { @@ -1419,6 +1422,15 @@ fn adjust_rank_content( } } + if let Some(rec_var) = recursion_var.into_variable() { + // THEORY: the recursion var has the same rank as the tag union itself + // all types it uses are also in the tags already, so it cannot influence the + // rank + debug_assert!( + rank >= adjust_rank(subs, young_mark, visit_mark, group_rank, rec_var) + ); + } + rank } @@ -1596,7 +1608,14 @@ fn instantiate_rigids_help( instantiate_rigids_help(subs, max_rank, pools, real_type_var); } - LambdaSet(subs::LambdaSet { solved }) => { + LambdaSet(subs::LambdaSet { + solved, + recursion_var, + }) => { + if let Some(rec_var) = recursion_var.into_variable() { + instantiate_rigids_help(subs, max_rank, pools, rec_var); + } + for (_, index) in solved.iter_all() { let slice = subs[index]; for var_index in slice { @@ -1872,7 +1891,10 @@ fn deep_copy_var_help( copy } - LambdaSet(subs::LambdaSet { solved }) => { + LambdaSet(subs::LambdaSet { + solved, + recursion_var, + }) => { let mut new_variable_slices = Vec::with_capacity(solved.len()); let mut new_variables = Vec::new(); @@ -1898,7 +1920,13 @@ fn deep_copy_var_help( }; let new_solved = UnionTags::from_slices(solved.tag_names(), new_variables); - let new_content = LambdaSet(subs::LambdaSet { solved: new_solved }); + let new_rec_var = + recursion_var.map(|rec_var| deep_copy_var_help(subs, max_rank, pools, rec_var)); + + let new_content = LambdaSet(subs::LambdaSet { + solved: new_solved, + recursion_var: new_rec_var, + }); subs.set(copy, make_descriptor(new_content)); diff --git a/bindgen/src/bindgen.rs b/bindgen/src/bindgen.rs index 7cb0aa27ff..a069ad07d1 100644 --- a/bindgen/src/bindgen.rs +++ b/bindgen/src/bindgen.rs @@ -118,9 +118,10 @@ fn add_type_help<'a>( fields, }) } - Content::LambdaSet(LambdaSet { solved }) => { - add_tag_union(env, opt_name, solved, var, types) - } + Content::LambdaSet(LambdaSet { + solved, + recursion_var: _, + }) => add_tag_union(env, opt_name, solved, var, types), Content::Structure(FlatType::TagUnion(tags, ext_var)) => { debug_assert!(ext_var_is_empty_tag_union(subs, *ext_var)); diff --git a/compiler/mono/src/copy.rs b/compiler/mono/src/copy.rs index 732e7f2fce..4df1ab082f 100644 --- a/compiler/mono/src/copy.rs +++ b/compiler/mono/src/copy.rs @@ -607,7 +607,11 @@ fn deep_copy_type_vars<'a>( }) } - LambdaSet(subs::LambdaSet { solved }) => { + LambdaSet(subs::LambdaSet { + solved, + recursion_var, + }) => { + let new_rec_var = recursion_var.map(|var| descend_var!(var)); for variables_slice_index in solved.variables() { let variables_slice = subs[variables_slice_index]; descend_slice!(variables_slice); @@ -626,7 +630,10 @@ fn deep_copy_type_vars<'a>( let new_solved = UnionTags::from_slices(solved.tag_names(), new_variable_slices); - LambdaSet(subs::LambdaSet { solved: new_solved }) + LambdaSet(subs::LambdaSet { + solved: new_solved, + recursion_var: new_rec_var, + }) }) } diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index f42e3241db..113a5bbcd4 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -2739,7 +2739,10 @@ fn generate_runtime_error_function<'a>( ) .unwrap(); - eprintln!("emitted runtime error function {:?}", &msg); + eprintln!( + "emitted runtime error function {:?} for layout {:?}", + &msg, layout + ); let runtime_error = Stmt::RuntimeError(msg.into_bump_str()); @@ -2911,6 +2914,7 @@ fn specialize_external<'a>( fn_var, roc_unify::unify::Mode::EQ, ); + dbg!(&_unified); // This will not hold for programs with type errors // let is_valid = matches!(unified, roc_unify::unify::Unified::Success(_)); @@ -2928,8 +2932,13 @@ fn specialize_external<'a>( } }; - let specialized = - build_specialized_proc_from_var(env, layout_cache, proc_name, pattern_symbols, fn_var)?; + let specialized = dbg!(build_specialized_proc_from_var( + env, + layout_cache, + proc_name, + pattern_symbols, + fn_var + ))?; // determine the layout of aliases/rigids exposed to the host let host_exposed_layouts = if host_exposed_variables.is_empty() { @@ -3537,6 +3546,7 @@ where Ok((proc, raw)) } Err(error) => { + dbg!(&error); env.subs.rollback_to(snapshot); layout_cache.rollback_to(cache_snapshot); diff --git a/compiler/mono/src/layout.rs b/compiler/mono/src/layout.rs index 9752922f35..e8a2ad0258 100644 --- a/compiler/mono/src/layout.rs +++ b/compiler/mono/src/layout.rs @@ -72,7 +72,7 @@ impl<'a> RawFunctionLayout<'a> { content: Content, ) -> Result { use roc_types::subs::Content::*; - match content { + match dbg!(content) { FlexVar(_) | RigidVar(_) => Err(LayoutProblem::UnresolvedTypeVar(var)), FlexAbleVar(_, _) | RigidAbleVar(_, _) => todo_abilities!("Not reachable yet"), RecursionVar { structure, .. } => { @@ -157,8 +157,9 @@ impl<'a> RawFunctionLayout<'a> { lset: subs::LambdaSet, ) -> Result { // Lambda set is just a tag union from the layout's perspective. - let subs::LambdaSet { solved } = lset; - Self::layout_from_flat_type(env, FlatType::TagUnion(solved, Variable::EMPTY_TAG_UNION)) + dbg!( + Self::layout_from_flat_type(env, dbg!(lset.as_tag_union())) + ) } fn layout_from_flat_type( @@ -1696,8 +1697,7 @@ fn layout_from_lambda_set<'a>( lset: subs::LambdaSet, ) -> Result, LayoutProblem> { // Lambda set is just a tag union from the layout's perspective. - let subs::LambdaSet { solved } = lset; - layout_from_flat_type(env, FlatType::TagUnion(solved, Variable::EMPTY_TAG_UNION)) + layout_from_flat_type(env, lset.as_tag_union()) } fn layout_from_flat_type<'a>( diff --git a/compiler/mono/src/layout_soa.rs b/compiler/mono/src/layout_soa.rs index acb381d9ab..43771fb422 100644 --- a/compiler/mono/src/layout_soa.rs +++ b/compiler/mono/src/layout_soa.rs @@ -156,12 +156,7 @@ impl FunctionLayout { lset: subs::LambdaSet, ) -> Result { // Lambda set is just a tag union from the layout's perspective. - let subs::LambdaSet { solved } = lset; - Self::from_flat_type( - layouts, - subs, - &FlatType::TagUnion(solved, Variable::EMPTY_TAG_UNION), - ) + Self::from_flat_type(layouts, subs, &lset.as_tag_union()) } fn from_flat_type( @@ -281,12 +276,7 @@ impl LambdaSet { lset: subs::LambdaSet, ) -> Result { // Lambda set is just a tag union from the layout's perspective. - let subs::LambdaSet { solved } = lset; - Self::from_flat_type( - layouts, - subs, - &FlatType::TagUnion(solved, Variable::EMPTY_TAG_UNION), - ) + Self::from_flat_type(layouts, subs, &lset.as_tag_union()) } fn from_flat_type( @@ -730,12 +720,7 @@ impl Layout { lset: subs::LambdaSet, ) -> Result { // Lambda set is just a tag union from the layout's perspective. - let subs::LambdaSet { solved } = lset; - Self::from_flat_type( - layouts, - subs, - &FlatType::TagUnion(solved, Variable::EMPTY_TAG_UNION), - ) + Self::from_flat_type(layouts, subs, &lset.as_tag_union()) } fn from_flat_type( diff --git a/compiler/solve/src/solve.rs b/compiler/solve/src/solve.rs index 3cb13388e0..8822f754a2 100644 --- a/compiler/solve/src/solve.rs +++ b/compiler/solve/src/solve.rs @@ -1893,7 +1893,12 @@ fn type_to_variable<'a>( Content::Structure(FlatType::EmptyTagUnion), )); - let content = Content::LambdaSet(subs::LambdaSet { solved }); + let content = Content::LambdaSet(subs::LambdaSet { + solved, + // We may figure out the lambda set is recursive during solving, but it never + // is to begin with. + recursion_var: OptVariable::NONE, + }); register_with_known_var(subs, destination, rank, pools, content) } @@ -2528,13 +2533,24 @@ fn check_for_infinite_type( let var = loc_var.value; while let Err((recursive, _chain)) = subs.occurs(var) { - // try to make a tag union recursive, see if that helps + // try to make a union recursive, see if that helps match subs.get_content_without_compacting(recursive) { &Content::Structure(FlatType::TagUnion(tags, ext_var)) => { + dbg!(1); subs.mark_tag_union_recursive(recursive, tags, ext_var); } + &Content::LambdaSet(subs::LambdaSet { + solved, + recursion_var: _, + }) => { + dbg!(2); + subs.mark_lambda_set_recursive(recursive, solved); + } - _other => circular_error(subs, problems, symbol, &loc_var), + _other => { + dbg!(3); + circular_error(subs, problems, symbol, &loc_var) + } } } } @@ -2875,7 +2891,10 @@ fn adjust_rank_content( rank } - LambdaSet(subs::LambdaSet { solved }) => { + LambdaSet(subs::LambdaSet { + solved, + recursion_var, + }) => { let mut rank = group_rank; for (_, index) in solved.iter_all() { @@ -2886,6 +2905,26 @@ fn adjust_rank_content( } } + if let (true, Some(rec_var)) = (cfg!(debug_assertions), recursion_var.into_variable()) { + // THEORY: unlike the situation for recursion vars under recursive tag unions, + // recursive vars inside lambda sets can't escape into higher let-generalized regions + // because lambda sets aren't user-facing. + // + // So the recursion var should be fully accounted by everything else in the lambda set + // (since it appears in the lambda set), and if the rank is higher, it's either a + // bug or our theory is wrong and indeed they can escape into higher regions. + let rec_var_rank = adjust_rank(subs, young_mark, visit_mark, group_rank, rec_var); + + debug_assert!( + rank >= rec_var_rank, + "rank was {:?} but recursion var <{:?}>{:?} has higher rank {:?}", + rank, + rec_var, + subs.get_content_without_compacting(rec_var), + rec_var_rank + ); + } + rank } @@ -3041,11 +3080,18 @@ fn instantiate_rigids_help(subs: &mut Subs, max_rank: Rank, initial: Variable) { stack.push(var); } - LambdaSet(subs::LambdaSet { solved }) => { + LambdaSet(subs::LambdaSet { + solved, + recursion_var, + }) => { for slice_index in solved.variables() { let slice = subs.variable_slices[slice_index.index as usize]; stack.extend(var_slice!(slice)); } + + if let Some(rec_var) = recursion_var.into_variable() { + stack.push(rec_var); + } } &RangedNumber(typ, _) => { stack.push(typ); @@ -3306,10 +3352,20 @@ fn deep_copy_var_help( subs.set_content_unchecked(copy, new_content); } - LambdaSet(subs::LambdaSet { solved }) => { + LambdaSet(subs::LambdaSet { + solved, + recursion_var, + }) => { let new_solved = copy_union_tags!(solved); + let new_rec_var = recursion_var.map(|v| work!(v)); - subs.set_content_unchecked(copy, LambdaSet(subs::LambdaSet { solved: new_solved })); + subs.set_content_unchecked( + copy, + LambdaSet(subs::LambdaSet { + solved: new_solved, + recursion_var: new_rec_var, + }), + ); } RangedNumber(typ, range) => { diff --git a/compiler/solve/tests/solve_expr.rs b/compiler/solve/tests/solve_expr.rs index 2ccd9356e6..de8feba3b5 100644 --- a/compiler/solve/tests/solve_expr.rs +++ b/compiler/solve/tests/solve_expr.rs @@ -6500,7 +6500,37 @@ mod solve_expr { r#" app "test" provides [main] to "./platform" - main = Result.mapErr + greeting = + hi = "Hello" + name = "World" + + "\(hi), \(name)!" + + main = + when nestHelp 4 is + _ -> greeting + + nestHelp : I64 -> XEffect {} + nestHelp = \m -> + when m is + 0 -> + always {} + + _ -> + always {} |> after \_ -> nestHelp (m - 1) + + + XEffect a := {} -> a + + always : a -> XEffect a + always = \x -> @XEffect (\{} -> x) + + after : XEffect a, (a -> XEffect b) -> XEffect b + after = \(@XEffect e), toB -> + @XEffect \{} -> + when toB (e {}) is + @XEffect e2 -> + e2 {} "# ), "", diff --git a/compiler/test_gen/src/gen_primitives.rs b/compiler/test_gen/src/gen_primitives.rs index e0a6c87992..37444ef976 100644 --- a/compiler/test_gen/src/gen_primitives.rs +++ b/compiler/test_gen/src/gen_primitives.rs @@ -3195,7 +3195,6 @@ fn alias_defined_out_of_order() { } #[test] -#[ignore = "recursive lambda set"] #[cfg(any(feature = "gen-llvm"))] fn recursively_build_effect() { assert_evals_to!( diff --git a/compiler/types/src/pretty_print.rs b/compiler/types/src/pretty_print.rs index 31a243f9a2..b32d864d39 100644 --- a/compiler/types/src/pretty_print.rs +++ b/compiler/types/src/pretty_print.rs @@ -221,7 +221,22 @@ fn find_names_needed( // TODO should we also look in the actual variable? // find_names_needed(_actual, subs, roots, root_appearances, names_taken); } - LambdaSet(subs::LambdaSet { solved: _ }) => {} + LambdaSet(subs::LambdaSet { + solved, + recursion_var, + }) => { + for slice_index in solved.variables() { + let slice = subs[slice_index]; + for var_index in slice { + let var = subs[var_index]; + find_names_needed(var, subs, roots, root_appearances, names_taken); + } + } + + if let Some(rec_var) = recursion_var.into_variable() { + find_names_needed(rec_var, subs, roots, root_appearances, names_taken); + } + } &RangedNumber(typ, _) => { find_names_needed(typ, subs, roots, root_appearances, names_taken); } @@ -944,7 +959,10 @@ pub fn resolve_lambda_set<'a>( fields: &mut Vec<(TagName, Vec)>, ) { match subs.get_content_without_compacting(var) { - Content::LambdaSet(subs::LambdaSet { solved }) => { + Content::LambdaSet(subs::LambdaSet { + solved, + recursion_var: _, + }) => { push_union_tags(subs, solved, fields); } c => internal_error!("called with a non-lambda set {:?}", c), diff --git a/compiler/types/src/subs.rs b/compiler/types/src/subs.rs index 8178817723..90df5d3bac 100644 --- a/compiler/types/src/subs.rs +++ b/compiler/types/src/subs.rs @@ -718,11 +718,15 @@ fn subs_fmt_content(this: &Content, subs: &Subs, f: &mut fmt::Formatter) -> fmt: SubsFmtContent(subs.get_content_without_compacting(*actual), subs) ) } - Content::LambdaSet(LambdaSet { solved }) => { + Content::LambdaSet(LambdaSet { + solved, + recursion_var, + }) => { write!( f, - "LambdaSet({:?})", - SubsFmtUnionTags(solved, Variable::EMPTY_TAG_UNION, subs) + "LambdaSet({:?}, <{:?}>)", + SubsFmtUnionTags(solved, Variable::EMPTY_TAG_UNION, subs), + recursion_var ) } Content::RangedNumber(typ, range) => { @@ -784,7 +788,9 @@ fn subs_fmt_flat_type(this: &FlatType, subs: &Subs, f: &mut fmt::Formatter) -> f let lambda_content = subs.get_content_without_compacting(*lambda_set); write!( f, - "], {:?}, <{:?}>{:?})", + "], <{:?}={:?}>{:?}, <{:?}>{:?})", + lambda_set, + subs.get_root_key_without_compacting(*lambda_set), SubsFmtContent(lambda_content, subs), *result, SubsFmtContent(result_content, subs) @@ -908,6 +914,16 @@ impl OptVariable { Variable(self.0) } } + + pub fn map(self, f: F) -> OptVariable + where + F: FnOnce(Variable) -> Variable, + { + self.into_variable() + .map(f) + .map(OptVariable::from) + .unwrap_or(OptVariable::NONE) + } } impl fmt::Debug for OptVariable { @@ -1759,6 +1775,30 @@ impl Subs { tags: UnionTags, ext_var: Variable, ) { + let (rec_var, new_tags) = self.mark_union_recursive_help(recursive, tags); + + let new_ext_var = self.explicit_substitute(recursive, rec_var, ext_var); + let flat_type = FlatType::RecursiveTagUnion(rec_var, new_tags, new_ext_var); + + self.set_content(recursive, Content::Structure(flat_type)); + } + + pub fn mark_lambda_set_recursive(&mut self, recursive: Variable, solved_lambdas: UnionTags) { + let (rec_var, new_tags) = self.mark_union_recursive_help(recursive, solved_lambdas); + + let new_lambda_set = Content::LambdaSet(LambdaSet { + solved: new_tags, + recursion_var: OptVariable::from(rec_var), + }); + + self.set_content(recursive, new_lambda_set); + } + + fn mark_union_recursive_help( + &mut self, + recursive: Variable, + tags: UnionTags, + ) -> (Variable, UnionTags) { let description = self.get(recursive); let rec_var = self.fresh_unnamed_flex_var(); @@ -1786,13 +1826,9 @@ impl Subs { self.variable_slices[variable_slice_index] = new_variables; } - let new_ext_var = self.explicit_substitute(recursive, rec_var, ext_var); - let new_tags = UnionTags::from_slices(tags.tag_names(), new_variable_slices); - let flat_type = FlatType::RecursiveTagUnion(rec_var, new_tags, new_ext_var); - - self.set_content(recursive, Content::Structure(flat_type)); + (rec_var, new_tags) } pub fn explicit_substitute( @@ -2020,6 +2056,45 @@ pub enum Content { pub struct LambdaSet { /// The resolved lambda symbols we know. pub solved: UnionTags, + /// Lambda sets may be recursive. For example, consider the annotated program + /// + /// ```text + /// XEffect : A -> B + /// + /// after : ({} -> XEffect) -> XEffect + /// after = + /// \cont -> + /// f = \A -[`f (typeof cont)]-> when cont {} is A -> B + /// f + /// + /// nestForever : {} -> XEffect + /// nestForever = \{} -[`nestForever]-> after nestForever + /// ^^^^^^^^^^^ {} -[`nestForever]-> A -[`f ({} -[`nestForever]-> A -[`f ...]-> B)]-> B + /// ``` + /// + /// where [`nestForever] and [`f ...] refer to the lambda sets of their respective arrows. `f` + /// captures `cont`. The usage of `after` in `nestForever` means that `nestForever` has type + /// ``nestForever : {} -[`nestForever]-> A -[`f (typeof cont)]-> B``. But also, `after` is called + /// with ``nestForever`, which means in this case `typeof cont = typeof nestForever``. So we see + /// that ``nestForever : {} -[`nestForever]-> A -[`f (typeof nestForever)]-> B``, and the lambda + /// set ``[`f (typeof nestForever)]`` is recursive. + /// + /// However, we don't know if a lambda set is recursive or not until type inference. + pub recursion_var: OptVariable, +} + +impl LambdaSet { + /// For things like layout generation, we don't care about differentiating betweent lambda sets + /// and tag unions - this function normalizes lambda sets appropriately as a possibly-recursive + /// tag union. + pub fn as_tag_union(&self) -> FlatType { + match self.recursion_var.into_variable() { + Some(rec_var) => { + FlatType::RecursiveTagUnion(rec_var, self.solved, Variable::EMPTY_TAG_UNION) + } + None => FlatType::TagUnion(self.solved, Variable::EMPTY_TAG_UNION), + } + } } #[derive(Clone, Copy, Debug, Default)] @@ -2803,10 +2878,19 @@ fn occurs( Ok(()) } - LambdaSet(self::LambdaSet { solved }) => { + LambdaSet(self::LambdaSet { + solved, + recursion_var, + }) => { let mut new_seen = seen.to_owned(); new_seen.push(root_var); + if include_recursion_var { + if let Some(v) = recursion_var.into_variable() { + new_seen.push(subs.get_root_key_without_compacting(v)); + } + } + occurs_union_tags(subs, root_var, &new_seen, include_recursion_var, solved) } RangedNumber(typ, _range_vars) => { @@ -2884,120 +2968,126 @@ fn explicit_substitute( use self::Content::*; use self::FlatType::*; let in_root = subs.get_root_key(in_var); - if seen.contains(&in_root) { + if subs.get_root_key(from) == in_root { + to + } else if seen.contains(&in_root) { in_var } else { seen.insert(in_root); - if subs.get_root_key(from) == subs.get_root_key(in_var) { - to - } else { - match subs.get(in_var).content { - FlexVar(_) - | RigidVar(_) - | FlexAbleVar(_, _) - | RigidAbleVar(_, _) - | RecursionVar { .. } - | Error => in_var, + match subs.get(in_var).content { + FlexVar(_) + | RigidVar(_) + | FlexAbleVar(_, _) + | RigidAbleVar(_, _) + | RecursionVar { .. } + | Error => in_var, - Structure(flat_type) => { - match flat_type { - Apply(symbol, args) => { - for var_index in args.into_iter() { - let var = subs[var_index]; - let answer = explicit_substitute(subs, from, to, var, seen); - subs[var_index] = answer; - } - - subs.set_content(in_var, Structure(Apply(symbol, args))); - } - Func(arg_vars, closure_var, ret_var) => { - for var_index in arg_vars.into_iter() { - let var = subs[var_index]; - let answer = explicit_substitute(subs, from, to, var, seen); - subs[var_index] = answer; - } - - let new_ret_var = explicit_substitute(subs, from, to, ret_var, seen); - let new_closure_var = - explicit_substitute(subs, from, to, closure_var, seen); - - subs.set_content( - in_var, - Structure(Func(arg_vars, new_closure_var, new_ret_var)), - ); - } - TagUnion(tags, ext_var) => { - let new_ext_var = explicit_substitute(subs, from, to, ext_var, seen); - - let union_tags = - explicit_substitute_union_tags(subs, from, to, tags, seen); - - subs.set_content(in_var, Structure(TagUnion(union_tags, new_ext_var))); - } - FunctionOrTagUnion(tag_name, symbol, ext_var) => { - let new_ext_var = explicit_substitute(subs, from, to, ext_var, seen); - subs.set_content( - in_var, - Structure(FunctionOrTagUnion(tag_name, symbol, new_ext_var)), - ); - } - RecursiveTagUnion(rec_var, tags, ext_var) => { - // NOTE rec_var is not substituted, verify that this is correct! - let new_ext_var = explicit_substitute(subs, from, to, ext_var, seen); - - let union_tags = - explicit_substitute_union_tags(subs, from, to, tags, seen); - - subs.set_content( - in_var, - Structure(RecursiveTagUnion(rec_var, union_tags, new_ext_var)), - ); - } - Record(vars_by_field, ext_var) => { - let new_ext_var = explicit_substitute(subs, from, to, ext_var, seen); - - for index in vars_by_field.iter_variables() { - let var = subs[index]; - let new_var = explicit_substitute(subs, from, to, var, seen); - subs[index] = new_var; - } - - subs.set_content(in_var, Structure(Record(vars_by_field, new_ext_var))); + Structure(flat_type) => { + match flat_type { + Apply(symbol, args) => { + for var_index in args.into_iter() { + let var = subs[var_index]; + let answer = explicit_substitute(subs, from, to, var, seen); + subs[var_index] = answer; } - EmptyRecord | EmptyTagUnion | Erroneous(_) => {} + subs.set_content(in_var, Structure(Apply(symbol, args))); + } + Func(arg_vars, closure_var, ret_var) => { + for var_index in arg_vars.into_iter() { + let var = subs[var_index]; + let answer = explicit_substitute(subs, from, to, var, seen); + subs[var_index] = answer; + } + + let new_ret_var = explicit_substitute(subs, from, to, ret_var, seen); + let new_closure_var = + explicit_substitute(subs, from, to, closure_var, seen); + + subs.set_content( + in_var, + Structure(Func(arg_vars, new_closure_var, new_ret_var)), + ); + } + TagUnion(tags, ext_var) => { + let new_ext_var = explicit_substitute(subs, from, to, ext_var, seen); + + let union_tags = explicit_substitute_union_tags(subs, from, to, tags, seen); + + subs.set_content(in_var, Structure(TagUnion(union_tags, new_ext_var))); + } + FunctionOrTagUnion(tag_name, symbol, ext_var) => { + let new_ext_var = explicit_substitute(subs, from, to, ext_var, seen); + subs.set_content( + in_var, + Structure(FunctionOrTagUnion(tag_name, symbol, new_ext_var)), + ); + } + RecursiveTagUnion(rec_var, tags, ext_var) => { + // NOTE rec_var is not substituted, verify that this is correct! + let new_ext_var = explicit_substitute(subs, from, to, ext_var, seen); + + let union_tags = explicit_substitute_union_tags(subs, from, to, tags, seen); + + subs.set_content( + in_var, + Structure(RecursiveTagUnion(rec_var, union_tags, new_ext_var)), + ); + } + Record(vars_by_field, ext_var) => { + let new_ext_var = explicit_substitute(subs, from, to, ext_var, seen); + + for index in vars_by_field.iter_variables() { + let var = subs[index]; + let new_var = explicit_substitute(subs, from, to, var, seen); + subs[index] = new_var; + } + + subs.set_content(in_var, Structure(Record(vars_by_field, new_ext_var))); } - in_var + EmptyRecord | EmptyTagUnion | Erroneous(_) => {} } - Alias(symbol, args, actual, kind) => { - for index in args.into_iter() { - let var = subs[index]; - let new_var = explicit_substitute(subs, from, to, var, seen); - subs[index] = new_var; - } - let new_actual = explicit_substitute(subs, from, to, actual, seen); - - subs.set_content(in_var, Alias(symbol, args, new_actual, kind)); - - in_var + in_var + } + Alias(symbol, args, actual, kind) => { + for index in args.into_iter() { + let var = subs[index]; + let new_var = explicit_substitute(subs, from, to, var, seen); + subs[index] = new_var; } - LambdaSet(self::LambdaSet { solved }) => { - let new_solved = explicit_substitute_union_tags(subs, from, to, solved, seen); - subs.set_content(in_var, LambdaSet(self::LambdaSet { solved: new_solved })); + let new_actual = explicit_substitute(subs, from, to, actual, seen); - in_var - } - RangedNumber(typ, range) => { - let new_typ = explicit_substitute(subs, from, to, typ, seen); + subs.set_content(in_var, Alias(symbol, args, new_actual, kind)); - subs.set_content(in_var, RangedNumber(new_typ, range)); + in_var + } + LambdaSet(self::LambdaSet { + solved, + recursion_var, + }) => { + // NOTE recursion_var is not substituted, verify that this is correct! + let new_solved = explicit_substitute_union_tags(subs, from, to, solved, seen); - in_var - } + subs.set_content( + in_var, + LambdaSet(self::LambdaSet { + solved: new_solved, + recursion_var, + }), + ); + + in_var + } + RangedNumber(typ, range) => { + let new_typ = explicit_substitute(subs, from, to, typ, seen); + + subs.set_content(in_var, RangedNumber(new_typ, range)); + + in_var } } } @@ -3092,8 +3182,15 @@ fn get_var_names( get_var_names(subs, subs[arg_var], answer) }), - LambdaSet(self::LambdaSet { solved }) => { - get_var_names_union_tags(subs, solved, taken_names) + LambdaSet(self::LambdaSet { + solved, + recursion_var, + }) => { + let taken_names = get_var_names_union_tags(subs, solved, taken_names); + match recursion_var.into_variable() { + Some(v) => get_var_names(subs, v, taken_names), + None => taken_names, + } } RangedNumber(typ, _) => get_var_names(subs, typ, taken_names), @@ -3332,9 +3429,19 @@ fn content_to_err_type( ErrorType::Alias(symbol, err_args, Box::new(err_type), kind) } - LambdaSet(self::LambdaSet { solved }) => { - ErrorType::TagUnion(union_tags_to_err_tags(subs, state, solved), TypeExt::Closed) - } + LambdaSet(self::LambdaSet { + solved, + recursion_var, + }) => match recursion_var.into_variable() { + None => { + ErrorType::TagUnion(union_tags_to_err_tags(subs, state, solved), TypeExt::Closed) + } + Some(v) => ErrorType::RecursiveTagUnion( + Box::new(var_to_err_type(subs, state, v)), + union_tags_to_err_tags(subs, state, solved), + TypeExt::Closed, + ), + }, RangedNumber(typ, range) => { let err_type = var_to_err_type(subs, state, typ); @@ -3640,11 +3747,16 @@ fn restore_help(subs: &mut Subs, initial: Variable) { stack.push(*var); } - LambdaSet(self::LambdaSet { solved }) => { + LambdaSet(self::LambdaSet { + solved, + recursion_var, + }) => { for slice_index in solved.variables() { let slice = variable_slices[slice_index.index as usize]; stack.extend(var_slice(slice)); } + + recursion_var.into_variable().map(|v| stack.push(v)); } RangedNumber(typ, _vars) => { @@ -3834,8 +3946,12 @@ impl StorageSubs { Self::offset_variable(offsets, *actual), *kind, ), - LambdaSet(self::LambdaSet { solved }) => LambdaSet(self::LambdaSet { + LambdaSet(self::LambdaSet { + solved, + recursion_var, + }) => LambdaSet(self::LambdaSet { solved: Self::offset_union_tags(offsets, *solved), + recursion_var: recursion_var.map(|v| Self::offset_variable(offsets, v)), }), RangedNumber(typ, range) => RangedNumber(Self::offset_variable(offsets, *typ), *range), Error => Content::Error, @@ -4236,10 +4352,17 @@ fn deep_copy_var_to_help(env: &mut DeepCopyVarToEnv<'_>, var: Variable) -> Varia copy } - LambdaSet(self::LambdaSet { solved }) => { + LambdaSet(self::LambdaSet { + solved, + recursion_var, + }) => { let new_solved = deep_copy_union_tags(env, solved); + let new_rec_var = recursion_var.map(|v| deep_copy_var_to_help(env, v)); - let new_content = LambdaSet(self::LambdaSet { solved: new_solved }); + let new_content = LambdaSet(self::LambdaSet { + solved: new_solved, + recursion_var: new_rec_var, + }); env.target.set(copy, make_descriptor(new_content)); copy } @@ -4680,10 +4803,18 @@ fn copy_import_to_help(env: &mut CopyImportEnv<'_>, max_rank: Rank, var: Variabl copy } - LambdaSet(self::LambdaSet { solved }) => { + LambdaSet(self::LambdaSet { + solved, + recursion_var, + }) => { let new_solved = copy_union_tags(env, max_rank, solved); + let new_rec_var = + recursion_var.map(|rec_var| copy_import_to_help(env, max_rank, rec_var)); - let new_content = LambdaSet(self::LambdaSet { solved: new_solved }); + let new_content = LambdaSet(self::LambdaSet { + solved: new_solved, + recursion_var: new_rec_var, + }); env.target.set(copy, make_descriptor(new_content)); diff --git a/compiler/unify/src/unify.rs b/compiler/unify/src/unify.rs index 3365a18869..469d6bdc4b 100644 --- a/compiler/unify/src/unify.rs +++ b/compiler/unify/src/unify.rs @@ -343,8 +343,8 @@ fn debug_print_unified_types(subs: &mut Subs, ctx: &Context, opt_outcome: Option "{}{}({:?}-{:?}): {:?} {:?} {} {:?} {:?}", " ".repeat(use_depth), prefix, - ctx.first, - ctx.second, + subs.get_root_key_without_compacting(ctx.first), + subs.get_root_key_without_compacting(ctx.second), ctx.first, SubsFmtContent(&content_1, subs), mode, @@ -801,11 +801,14 @@ fn unify_lambda_set( Content::LambdaSet(other_lambda_set) => { unify_lambda_set_help(subs, pool, ctx, lambda_set, *other_lambda_set) } + RecursionVar { structure, .. } => { + // suppose that the recursion var is a lambda set + unify_pool(subs, pool, ctx.first, *structure, ctx.mode) + } RigidVar(..) | RigidAbleVar(..) => mismatch!("Lambda sets never unify with rigid"), FlexAbleVar(..) => mismatch!("Lambda sets should never have abilities attached to them"), Structure(..) => mismatch!("Lambda set cannot unify with non-lambda set structure"), RangedNumber(..) => mismatch!("Lambda sets are never numbers"), - RecursionVar { .. } => mismatch!("Lambda set not expected to be recursive!"), Alias(..) => mismatch!("Lambda set can never be directly under an alias!"), Error => merge(subs, ctx, Error), } @@ -822,8 +825,21 @@ fn unify_lambda_set_help( // LambdaSets unify like TagUnions, but can grow unbounded regardless of the extension // variable. - let LambdaSet { solved: solved1 } = lset1; - let LambdaSet { solved: solved2 } = lset2; + let LambdaSet { + solved: solved1, + recursion_var: rec1, + } = lset1; + let LambdaSet { + solved: solved2, + recursion_var: rec2, + } = lset2; + + debug_assert!( + (rec1.into_variable().into_iter()) + .chain(rec2.into_variable().into_iter()) + .all(|v| is_recursion_var(subs, v)), + "Recursion var is present, but it doesn't have a recursive content!" + ); let (separate_solved, _, _) = separate_union_tags( subs, @@ -843,7 +859,7 @@ fn unify_lambda_set_help( let mut joined_lambdas = vec![]; for (tag_name, (vars1, vars2)) in in_both { - let mut joined_vars = vec![]; + let mut matching_vars = vec![]; if vars1.len() != vars2.len() { continue; // this is a type mismatch; not adding the tag will trigger it below. @@ -853,16 +869,32 @@ fn unify_lambda_set_help( for (var1, var2) in (vars1.into_iter()).zip(vars2.into_iter()) { let (var1, var2) = (subs[var1], subs[var2]); + // Lambda sets are effectively tags under another name, and their usage can also result + // in the arguments of a lambda name being recursive. It very well may happen that + // during unification, a lambda set previously marked as not recursive becomes + // recursive. See the docs of [LambdaSet] for one example, or https://github.com/rtfeldman/roc/pull/2307. + // + // Like with tag unions, if it has, we'll always pass through this branch. So, take + // this opportunity to promote the lambda set to recursive if need be. + maybe_mark_union_recursive(subs, var1); + maybe_mark_union_recursive(subs, var2); + let outcome = unify_pool(subs, pool, var1, var2, ctx.mode); - if outcome.mismatches.is_empty() { - // otherwise this is a type mismatch; not adding the variable will trigger it below. - joined_vars.push(var1); + // TODO: i think we can get rid of this + // clearly, this is very suspicious: these variables have just been unified. And yet, + // not doing this leads to stack overflows + if rec2.is_some() { + if outcome.mismatches.is_empty() { + matching_vars.push(var2); + } + } else if outcome.mismatches.is_empty() { + matching_vars.push(var1); } } - if joined_vars.len() == num_vars { - joined_lambdas.push((tag_name, joined_vars)); + if matching_vars.len() == num_vars { + joined_lambdas.push((tag_name, matching_vars)); } } @@ -885,8 +917,17 @@ fn unify_lambda_set_help( }), ); + let recursion_var = match (rec1.into_variable(), rec2.into_variable()) { + // Prefer left when it's available. + (Some(rec), _) | (_, Some(rec)) => OptVariable::from(rec), + (None, None) => OptVariable::NONE, + }; + let new_solved = UnionTags::insert_into_subs(subs, all_lambdas); - let new_lambda_set = Content::LambdaSet(LambdaSet { solved: new_solved }); + let new_lambda_set = Content::LambdaSet(LambdaSet { + solved: new_solved, + recursion_var, + }); merge(subs, ctx, new_lambda_set) } else { @@ -1483,31 +1524,52 @@ enum OtherTags2 { ), } -fn maybe_mark_tag_union_recursive(subs: &mut Subs, tag_union_var: Variable) { - 'outer: while let Err((recursive, chain)) = subs.occurs(tag_union_var) { +/// Promotes a non-recursive tag union or lambda set to its recursive variant, if it is found to be +/// recursive. +fn maybe_mark_union_recursive(subs: &mut Subs, union_var: Variable) { + 'outer: while let Err((recursive, chain)) = subs.occurs(union_var) { let description = subs.get(recursive); - if let Content::Structure(FlatType::TagUnion(tags, ext_var)) = description.content { - subs.mark_tag_union_recursive(recursive, tags, ext_var); - } else { - // walk the chain till we find a tag union - for v in &chain[..chain.len() - 1] { - let description = subs.get(*v); - if let Content::Structure(FlatType::TagUnion(tags, ext_var)) = description.content { - subs.mark_tag_union_recursive(*v, tags, ext_var); - continue 'outer; - } + match description.content { + Content::Structure(FlatType::TagUnion(tags, ext_var)) => { + subs.mark_tag_union_recursive(recursive, tags, ext_var); } + LambdaSet(self::LambdaSet { + solved, + recursion_var: OptVariable::NONE, + }) => { + subs.mark_lambda_set_recursive(recursive, solved); + } + _ => { + // walk the chain till we find a tag union or lambda set + for v in &chain[..chain.len() - 1] { + let description = subs.get(*v); + match description.content { + Content::Structure(FlatType::TagUnion(tags, ext_var)) => { + subs.mark_tag_union_recursive(*v, tags, ext_var); + continue 'outer; + } + LambdaSet(self::LambdaSet { + solved, + recursion_var: OptVariable::NONE, + }) => { + subs.mark_lambda_set_recursive(recursive, solved); + continue 'outer; + } + _ => { /* fall through */ } + } + } - // Might not be any tag union if we only pass through `Apply`s. Otherwise, we have a bug! - if chain.iter().all(|&v| { - matches!( - subs.get_content_without_compacting(v), - Content::Structure(FlatType::Apply(..)) - ) - }) { - return; - } else { - internal_error!("recursive loop does not contain a tag union") + // Might not be any tag union if we only pass through `Apply`s. Otherwise, we have a bug! + if chain.iter().all(|&v| { + matches!( + subs.get_content_without_compacting(v), + Content::Structure(FlatType::Apply(..)) + ) + }) { + return; + } else { + internal_error!("recursive loop does not contain a tag union") + } } } } @@ -1563,8 +1625,8 @@ fn unify_shared_tags_new( // since we're expanding tag unions to equal depths as described above, // we'll always pass through this branch. So, we promote tag unions to recursive // ones here if it turns out they are that. - maybe_mark_tag_union_recursive(subs, actual); - maybe_mark_tag_union_recursive(subs, expected); + maybe_mark_union_recursive(subs, actual); + maybe_mark_union_recursive(subs, expected); let mut outcome = Outcome::default(); @@ -2094,7 +2156,8 @@ fn unify_recursion( ), LambdaSet(..) => { - mismatch!("RecursionVar {:?} with LambdaSet {:?}", ctx.first, &other) + // suppose that the recursion var is a lambda set + unify_pool(subs, pool, structure, ctx.second, ctx.mode) } Error => merge(subs, ctx, Error), @@ -2172,7 +2235,10 @@ fn unify_function_or_tag_union_and_func( { let tag_name = TagName::Closure(tag_symbol); let union_tags = UnionTags::tag_without_arguments(subs, tag_name); - let lambda_set_content = LambdaSet(self::LambdaSet { solved: union_tags }); + let lambda_set_content = LambdaSet(self::LambdaSet { + solved: union_tags, + recursion_var: OptVariable::NONE, + }); let tag_lambda_set = register( subs,