From b61481c6e771f6c3a1969cf00d0d904a9ca852b5 Mon Sep 17 00:00:00 2001 From: Ayaz Hafiz Date: Fri, 8 Apr 2022 17:05:57 -0400 Subject: [PATCH] Try another strategy - fix recursion vars during typechecking --- compiler/mono/src/layout.rs | 8 ++-- compiler/types/src/subs.rs | 74 +++++++++++++++++++++++++++--------- compiler/unify/src/unify.rs | 76 ++++++++++++++++++++++++++++++++++--- 3 files changed, 131 insertions(+), 27 deletions(-) diff --git a/compiler/mono/src/layout.rs b/compiler/mono/src/layout.rs index e21d054ed0..58fe798d23 100644 --- a/compiler/mono/src/layout.rs +++ b/compiler/mono/src/layout.rs @@ -196,7 +196,7 @@ impl<'a> RawFunctionLayout<'a> { Self::from_var(env, var) } _ => { - let layout = layout_from_flat_type(env, var, flat_type)?; + let layout = layout_from_flat_type(env, flat_type)?; Ok(Self::ZeroArgumentThunk(layout)) } } @@ -960,7 +960,7 @@ impl<'a> Layout<'a> { let structure_content = env.subs.get_content_without_compacting(structure); Self::new_help(env, structure, *structure_content) } - Structure(flat_type) => layout_from_flat_type(env, var, flat_type), + Structure(flat_type) => layout_from_flat_type(env, flat_type), Alias(symbol, _args, actual_var, _) => { if let Some(int_width) = IntWidth::try_from_symbol(symbol) { @@ -1573,7 +1573,6 @@ impl<'a> Builtin<'a> { fn layout_from_flat_type<'a>( env: &mut Env<'a, '_>, - var: Variable, flat_type: FlatType, ) -> Result, LayoutProblem> { use roc_types::subs::FlatType::*; @@ -1776,7 +1775,6 @@ fn layout_from_flat_type<'a>( } } - env.insert_seen(var); env.insert_seen(rec_var); for (index, &(_name, variables)) in tags_vec.iter().enumerate() { if matches!(nullable, Some(i) if i == index as TagIdIntType) { @@ -1793,6 +1791,7 @@ fn layout_from_flat_type<'a>( continue; } + let content = subs.get_content_without_compacting(var); tag_layout.push(Layout::from_var(env, var)?); } @@ -1806,7 +1805,6 @@ fn layout_from_flat_type<'a>( tag_layouts.push(tag_layout.into_bump_slice()); } env.remove_seen(rec_var); - env.insert_seen(var); let union_layout = if let Some(tag_id) = nullable { match tag_layouts.into_bump_slice() { diff --git a/compiler/types/src/subs.rs b/compiler/types/src/subs.rs index a94cca1be6..256e1b7cef 100644 --- a/compiler/types/src/subs.rs +++ b/compiler/types/src/subs.rs @@ -1906,7 +1906,14 @@ impl Subs { } pub fn occurs(&self, var: Variable) -> Result<(), (Variable, Vec)> { - occurs(self, &[], var) + occurs(self, &[], var, false) + } + + pub fn occurs_including_recursion_vars( + &self, + var: Variable, + ) -> Result<(), (Variable, Vec)> { + occurs(self, &[], var, true) } pub fn mark_tag_union_recursive( @@ -2876,6 +2883,7 @@ fn occurs( subs: &Subs, seen: &[Variable], input_var: Variable, + include_recursion_var: bool, ) -> Result<(), (Variable, Vec)> { use self::Content::*; use self::FlatType::*; @@ -2899,47 +2907,77 @@ fn occurs( new_seen.push(root_var); match flat_type { - Apply(_, args) => { - short_circuit(subs, root_var, &new_seen, subs.get_subs_slice(*args).iter()) - } + Apply(_, args) => short_circuit( + subs, + root_var, + &new_seen, + subs.get_subs_slice(*args).iter(), + include_recursion_var, + ), Func(arg_vars, closure_var, ret_var) => { let it = once(ret_var) .chain(once(closure_var)) .chain(subs.get_subs_slice(*arg_vars).iter()); - short_circuit(subs, root_var, &new_seen, it) + short_circuit(subs, root_var, &new_seen, it, include_recursion_var) } Record(vars_by_field, ext_var) => { let slice = SubsSlice::new(vars_by_field.variables_start, vars_by_field.length); let it = once(ext_var).chain(subs.get_subs_slice(slice).iter()); - short_circuit(subs, root_var, &new_seen, it) + short_circuit(subs, root_var, &new_seen, it, include_recursion_var) } TagUnion(tags, ext_var) => { for slice_index in tags.variables() { let slice = subs[slice_index]; for var_index in slice { let var = subs[var_index]; - short_circuit_help(subs, root_var, &new_seen, var)?; + short_circuit_help( + subs, + root_var, + &new_seen, + var, + include_recursion_var, + )?; } } - short_circuit_help(subs, root_var, &new_seen, *ext_var) + short_circuit_help( + subs, + root_var, + &new_seen, + *ext_var, + include_recursion_var, + ) } FunctionOrTagUnion(_, _, ext_var) => { let it = once(ext_var); - short_circuit(subs, root_var, &new_seen, it) + short_circuit(subs, root_var, &new_seen, it, include_recursion_var) } - RecursiveTagUnion(_rec_var, tags, ext_var) => { - // TODO rec_var is excluded here, verify that this is correct + RecursiveTagUnion(rec_var, tags, ext_var) => { + if include_recursion_var { + new_seen.push(*rec_var); + } for slice_index in tags.variables() { let slice = subs[slice_index]; for var_index in slice { let var = subs[var_index]; - short_circuit_help(subs, root_var, &new_seen, var)?; + short_circuit_help( + subs, + root_var, + &new_seen, + var, + include_recursion_var, + )?; } } - short_circuit_help(subs, root_var, &new_seen, *ext_var) + short_circuit_help( + subs, + root_var, + &new_seen, + *ext_var, + include_recursion_var, + ) } EmptyRecord | EmptyTagUnion | Erroneous(_) => Ok(()), } @@ -2950,7 +2988,7 @@ fn occurs( for var_index in args.into_iter() { let var = subs[var_index]; - short_circuit_help(subs, root_var, &new_seen, var)?; + short_circuit_help(subs, root_var, &new_seen, var, include_recursion_var)?; } Ok(()) @@ -2959,7 +2997,7 @@ fn occurs( let mut new_seen = seen.to_owned(); new_seen.push(root_var); - short_circuit_help(subs, root_var, &new_seen, *typ)?; + short_circuit_help(subs, root_var, &new_seen, *typ, include_recursion_var)?; // _range_vars excluded because they are not explicitly part of the type. Ok(()) @@ -2974,12 +3012,13 @@ fn short_circuit<'a, T>( root_key: Variable, seen: &[Variable], iter: T, + include_recursion_var: bool, ) -> Result<(), (Variable, Vec)> where T: Iterator, { for var in iter { - short_circuit_help(subs, root_key, seen, *var)?; + short_circuit_help(subs, root_key, seen, *var, include_recursion_var)?; } Ok(()) @@ -2991,8 +3030,9 @@ fn short_circuit_help( root_key: Variable, seen: &[Variable], var: Variable, + include_recursion_var: bool, ) -> Result<(), (Variable, Vec)> { - if let Err((v, mut vec)) = occurs(subs, seen, var) { + if let Err((v, mut vec)) = occurs(subs, seen, var, include_recursion_var) { vec.push(root_key); return Err((v, vec)); } diff --git a/compiler/unify/src/unify.rs b/compiler/unify/src/unify.rs index 10738f6382..530db674c8 100644 --- a/compiler/unify/src/unify.rs +++ b/compiler/unify/src/unify.rs @@ -5,7 +5,8 @@ use roc_module::symbol::Symbol; use roc_types::subs::Content::{self, *}; use roc_types::subs::{ AliasVariables, Descriptor, ErrorTypeContext, FlatType, GetSubsSlice, Mark, OptVariable, - RecordFields, Subs, SubsIndex, SubsSlice, UnionTags, Variable, VariableSubsSlice, + RecordFields, Subs, SubsFmtContent, SubsIndex, SubsSlice, UnionTags, Variable, + VariableSubsSlice, }; use roc_types::types::{AliasKind, DoesNotImplementAbility, ErrorType, Mismatch, RecordField}; @@ -316,10 +317,10 @@ fn debug_print_unified_types(subs: &mut Subs, ctx: &Context, opt_outcome: Option ctx.first, ctx.second, ctx.first, - roc_types::subs::SubsFmtContent(&content_1, subs), + SubsFmtContent(&content_1, subs), mode, ctx.second, - roc_types::subs::SubsFmtContent(&content_2, subs), + SubsFmtContent(&content_2, subs), ); unsafe { UNIFICATION_DEPTH = new_depth }; @@ -576,7 +577,13 @@ fn unify_structure( RecursionVar { structure, .. } => match flat_type { FlatType::TagUnion(_, _) => { // unify the structure with this unrecursive tag union - unify_pool(subs, pool, ctx.first, *structure, ctx.mode) + let mut problems = unify_pool(subs, pool, ctx.first, *structure, ctx.mode); + + problems.extend(fix_tag_union_recursion_variable( + subs, ctx, ctx.first, other, + )); + + problems } FlatType::RecursiveTagUnion(rec, _, _) => { debug_assert!(is_recursion_var(subs, *rec)); @@ -585,7 +592,13 @@ fn unify_structure( } FlatType::FunctionOrTagUnion(_, _, _) => { // unify the structure with this unrecursive tag union - unify_pool(subs, pool, ctx.first, *structure, ctx.mode) + let mut problems = unify_pool(subs, pool, ctx.first, *structure, ctx.mode); + + problems.extend(fix_tag_union_recursion_variable( + subs, ctx, ctx.first, other, + )); + + problems } // Only tag unions can be recursive; everything else is an error. _ => mismatch!( @@ -643,6 +656,59 @@ fn unify_structure( } } +/// Ensures that a non-recursive tag union, when unified with a recursion var to become a recursive +/// tag union, properly contains a recursion variable that recurses on itself. +// +// When might this not be the case? For example, in the code +// +// Indirect : [ Indirect ConsList ] +// +// ConsList : [ Nil, Cons Indirect ] +// +// l : ConsList +// l = Cons (Indirect (Cons (Indirect Nil))) +// # ^^^^^^^^^^^^^^^~~~~~~~~~~~~~~~~~~~~~^ region-a +// # ~~~~~~~~~~~~~~~~~~~~~ region-b +// l +// +// Suppose `ConsList` has the expanded type `[ Nil, Cons [ Indirect ] ] as `. +// After unifying the tag application annotated "region-b" with the recursion variable ``, +// the tentative total-type of the application annotated "region-a" would be +// ` = [ Nil, Cons [ Indirect ] ] as `. That is, the type of the recursive tag union +// would be inlined at the site "v", rather than passing through the correct recursion variable +// "rec" first. +// +// This is not incorrect from a type perspective, but causes problems later on for e.g. layout +// determination, which expects recursion variables to be placed correctly. Attempting to detect +// this during layout generation does not work so well because it may be that there *are* recursive +// tag unions that should be inlined, and not pass through recursion variables. So instead, try to +// resolve these cases here. +// +// See tests labeled "issue_2810" for more examples. +fn fix_tag_union_recursion_variable( + subs: &mut Subs, + ctx: &Context, + tag_union_promoted_to_recursive: Variable, + recursion_var: &Content, +) -> Outcome { + debug_assert!(matches!( + subs.get_content_without_compacting(tag_union_promoted_to_recursive), + Structure(FlatType::RecursiveTagUnion(..)) + )); + + let f = subs.get_content_without_compacting(tag_union_promoted_to_recursive); + + let has_recursing_recursive_variable = subs + .occurs_including_recursion_vars(tag_union_promoted_to_recursive) + .is_err(); + + if !has_recursing_recursive_variable { + merge(subs, ctx, recursion_var.clone()) + } else { + vec![] + } +} + fn unify_record( subs: &mut Subs, pool: &mut Pool,