Instantiate recursive aliases to their smallest closures

Now, when we have two aliases like

```
T a : [ A, B (U a) ]
U a : [ C, D (T a) ]
```

during the first pass, we simply canonicalize them but add neither to
the scope. This means that `T` will not be instantiated in the
definition of `U`. Only in the second pass, during correction, do we
instantiate both aliases **independently**:

```
T a : [ A, B [ C, D (T a) ] ]
U a : [ C, D [ A, B (U a) ] ]
```

and now we can mark each recursive, individually:

```
T a : [ A, B [ C, D <rec1> ] ] as <rec1>
U a : [ C, D [ A, B <rec2> ] ] as <rec2>
```

This means that the surface types shown to users might be a bit larger,
but it has the benefit that everything needed to understand a layout of
a type in later passes is stored on the type directly, and we don't need
to keep alias mappings.

Since we sort by connected components, this should be complete.

Closes #2458
This commit is contained in:
ayazhafiz 2022-02-11 08:43:33 -05:00
parent c064c50036
commit 8c0e39211d
5 changed files with 200 additions and 131 deletions

View file

@ -9,6 +9,7 @@ use crate::expr::{
}; };
use crate::pattern::{bindings_from_patterns, canonicalize_pattern, Pattern}; use crate::pattern::{bindings_from_patterns, canonicalize_pattern, Pattern};
use crate::procedure::References; use crate::procedure::References;
use crate::scope::create_alias;
use crate::scope::Scope; use crate::scope::Scope;
use roc_collections::all::{default_hasher, ImMap, ImSet, MutMap, MutSet, SendMap}; use roc_collections::all::{default_hasher, ImMap, ImSet, MutMap, MutSet, SendMap};
use roc_module::ident::Lowercase; use roc_module::ident::Lowercase;
@ -302,14 +303,21 @@ pub fn canonicalize_defs<'a>(
continue; continue;
} }
scope.add_alias(symbol, name.region, can_vars.clone(), can_ann.typ.clone()); let alias = create_alias(symbol, name.region, can_vars.clone(), can_ann.typ.clone());
let alias = scope.lookup_alias(symbol).expect("alias is added to scope");
aliases.insert(symbol, alias.clone()); aliases.insert(symbol, alias.clone());
} }
// Now that we know the alias dependency graph, we can try to insert recursion variables // Now that we know the alias dependency graph, we can try to insert recursion variables
// where aliases are recursive tag unions, or detect illegal recursions. // where aliases are recursive tag unions, or detect illegal recursions.
correct_mutual_recursive_type_alias(env, &mut scope, &mut aliases, var_store); let mut aliases = correct_mutual_recursive_type_alias(env, &aliases, var_store);
for (symbol, alias) in aliases.iter() {
scope.add_alias(
*symbol,
alias.region,
alias.type_variables.clone(),
alias.typ.clone(),
);
}
// Now that we have the scope completely assembled, and shadowing resolved, // Now that we have the scope completely assembled, and shadowing resolved,
// we're ready to canonicalize any body exprs. // we're ready to canonicalize any body exprs.
@ -1534,18 +1542,17 @@ fn pending_typed_body<'a>(
/// Make aliases recursive /// Make aliases recursive
fn correct_mutual_recursive_type_alias<'a>( fn correct_mutual_recursive_type_alias<'a>(
env: &mut Env<'a>, env: &mut Env<'a>,
scope: &mut Scope, original_aliases: &SendMap<Symbol, Alias>,
aliases: &mut SendMap<Symbol, Alias>,
var_store: &mut VarStore, var_store: &mut VarStore,
) { ) -> SendMap<Symbol, Alias> {
let mut symbols_introduced = ImSet::default(); let mut symbols_introduced = ImSet::default();
for (key, _) in aliases.iter() { for (key, _) in original_aliases.iter() {
symbols_introduced.insert(*key); symbols_introduced.insert(*key);
} }
let all_successors_with_self = |symbol: &Symbol| -> ImSet<Symbol> { let all_successors_with_self = |symbol: &Symbol| -> ImSet<Symbol> {
match aliases.get(symbol) { match original_aliases.get(symbol) {
Some(alias) => { Some(alias) => {
let mut loc_succ = alias.typ.symbols(); let mut loc_succ = alias.typ.symbols();
// remove anything that is not defined in the current block // remove anything that is not defined in the current block
@ -1558,111 +1565,80 @@ fn correct_mutual_recursive_type_alias<'a>(
}; };
// TODO investigate should this be in a loop? // TODO investigate should this be in a loop?
let defined_symbols: Vec<Symbol> = aliases.keys().copied().collect(); let defined_symbols: Vec<Symbol> = original_aliases.keys().copied().collect();
let cycles = strongly_connected_components(&defined_symbols, all_successors_with_self); let cycles = strongly_connected_components(&defined_symbols, all_successors_with_self);
let mut solved_aliases = SendMap::default();
'next_cycle: for cycle in cycles { for cycle in cycles {
debug_assert!(!cycle.is_empty()); debug_assert!(!cycle.is_empty());
let mut pending_aliases: SendMap<_, _> = cycle
.iter()
.map(|&sym| (sym, original_aliases.get(&sym).unwrap().clone()))
.collect();
// Make sure we report only one error for the cycle, not an error for every // Make sure we report only one error for the cycle, not an error for every
// alias in the cycle. // alias in the cycle.
let mut can_still_report_error = true; let mut can_still_report_error = true;
// Go through and mark every self- and mutually-recursive alias cycle recursive.
if cycle.len() == 1 {
let symbol = cycle[0];
let alias = aliases.get_mut(&symbol).unwrap();
if !alias.typ.contains_symbol(symbol) {
// This alias is neither self nor mutually recursive.
continue 'next_cycle;
}
// This is a self-recursive cycle.
let mut can_still_report_error = true;
let mut opt_rec_var = None;
let _made_recursive = make_tag_union_of_alias_recursive(
env,
symbol,
alias,
vec![],
var_store,
&mut can_still_report_error,
&mut opt_rec_var,
);
scope.add_alias(
symbol,
alias.region,
alias.type_variables.clone(),
alias.typ.clone(),
);
} else {
// This is a mutually recursive cycle.
let mut opt_rec_var = None;
// First mark everything in the cycle recursive, as it needs to be.
for symbol in cycle.iter() {
let alias = aliases.get_mut(&symbol).unwrap();
let _made_recursive = make_tag_union_of_alias_recursive(
env,
*symbol,
alias,
vec![],
var_store,
&mut can_still_report_error,
&mut opt_rec_var,
);
}
// Now go through and instantiate references that are recursive, but we didn't know
// they were until now.
//
// TODO use itertools to be more efficient here
for &rec in cycle.iter() { for &rec in cycle.iter() {
let mut to_instantiate = ImMap::default(); // First, we need to instantiate the alias with any symbols in the currrent module it
let mut others = Vec::with_capacity(cycle.len() - 1); // depends on.
// We only need to worry about symbols in this SCC or any prior one, since the SCCs
// were sorted topologically, and we've already instantiated aliases coming from other
// modules.
let mut to_instantiate: ImMap<_, _> = solved_aliases.clone().into_iter().collect();
let mut others_in_scc = Vec::with_capacity(cycle.len() - 1);
for &other in cycle.iter() { for &other in cycle.iter() {
if rec != other { if rec != other {
others.push(other); others_in_scc.push(other);
if let Some(alias) = aliases.get(&other) { if let Some(alias) = original_aliases.get(&other) {
to_instantiate.insert(other, alias.clone()); to_instantiate.insert(other, alias.clone());
} }
} }
} }
if let Some(alias) = aliases.get_mut(&rec) { let alias = pending_aliases.get_mut(&rec).unwrap();
alias.typ.instantiate_aliases( alias.typ.instantiate_aliases(
alias.region, alias.region,
&to_instantiate, &to_instantiate,
var_store, var_store,
&mut ImSet::default(), &mut ImSet::default(),
); );
}
// Now mark the alias recursive, if it needs to be.
let is_self_recursive = alias.typ.contains_symbol(rec);
let is_mutually_recursive = cycle.len() > 1;
if is_self_recursive || is_mutually_recursive {
let mut opt_rec_var = None;
let _made_recursive = make_tag_union_of_alias_recursive(
env,
rec,
alias,
vec![],
var_store,
&mut can_still_report_error,
&mut opt_rec_var,
);
} }
} }
// The cycle we just marked recursive and instantiated may still be illegal cycles, if // The cycle we just instantiated and marked recursive may still be an illegal cycle, if
// all the types in the cycle are narrow newtypes. We can't figure this out until now, // all the types in the cycle are narrow newtypes. We can't figure this out until now,
// because we need all the types to be deeply instantiated. // because we need all the types to be deeply instantiated.
let all_are_narrow = cycle.iter().all(|sym| { let all_are_narrow = cycle.iter().all(|sym| {
let typ = &aliases.get(sym).unwrap().typ; let typ = &pending_aliases.get(sym).unwrap().typ;
typ.is_tag_union_like() && typ.is_narrow() typ.is_tag_union_like() && typ.is_narrow()
}); });
if !all_are_narrow {
// We pass through at least one tag union that has a non-recursive variant, so this
// cycle is legal.
continue 'next_cycle;
}
if all_are_narrow {
// This cycle is illegal!
let mut rest = cycle; let mut rest = cycle;
let alias_name = rest.pop().unwrap(); let alias_name = rest.pop().unwrap();
let alias = aliases.get_mut(&alias_name).unwrap(); let alias = pending_aliases.get_mut(&alias_name).unwrap();
mark_cyclic_alias( mark_cyclic_alias(
env, env,
@ -1673,6 +1649,12 @@ fn correct_mutual_recursive_type_alias<'a>(
can_still_report_error, can_still_report_error,
) )
} }
// Now, promote all resolved aliases in this cycle as solved.
solved_aliases.extend(pending_aliases);
}
solved_aliases
} }
fn make_tag_union_of_alias_recursive<'a>( fn make_tag_union_of_alias_recursive<'a>(

View file

@ -17,7 +17,7 @@ pub struct Scope {
symbols: SendMap<Symbol, Region>, symbols: SendMap<Symbol, Region>,
/// The type aliases currently in scope /// The type aliases currently in scope
aliases: SendMap<Symbol, Alias>, pub aliases: SendMap<Symbol, Alias>,
/// The current module being processed. This will be used to turn /// The current module being processed. This will be used to turn
/// unqualified idents into Symbols. /// unqualified idents into Symbols.
@ -181,6 +181,21 @@ impl Scope {
vars: Vec<Loc<(Lowercase, Variable)>>, vars: Vec<Loc<(Lowercase, Variable)>>,
typ: Type, typ: Type,
) { ) {
let alias = create_alias(name, region, vars, typ);
self.aliases.insert(name, alias);
}
pub fn contains_alias(&mut self, name: Symbol) -> bool {
self.aliases.contains_key(&name)
}
}
pub fn create_alias(
name: Symbol,
region: Region,
vars: Vec<Loc<(Lowercase, Variable)>>,
typ: Type,
) -> Alias {
let roc_types::types::VariableDetail { let roc_types::types::VariableDetail {
type_variables, type_variables,
lambda_set_variables, lambda_set_variables,
@ -209,18 +224,11 @@ impl Scope {
.map(|v| roc_types::types::LambdaSet(Type::Variable(v))) .map(|v| roc_types::types::LambdaSet(Type::Variable(v)))
.collect(); .collect();
let alias = Alias { Alias {
region, region,
type_variables: vars, type_variables: vars,
lambda_set_variables, lambda_set_variables,
recursion_variables, recursion_variables,
typ, typ,
};
self.aliases.insert(name, alias);
}
pub fn contains_alias(&mut self, name: Symbol) -> bool {
self.aliases.contains_key(&name)
} }
} }

View file

@ -3036,7 +3036,6 @@ mod solve_expr {
} }
#[test] #[test]
#[ignore]
fn typecheck_mutually_recursive_tag_union_2() { fn typecheck_mutually_recursive_tag_union_2() {
infer_eq_without_problem( infer_eq_without_problem(
indoc!( indoc!(
@ -3064,7 +3063,6 @@ mod solve_expr {
} }
#[test] #[test]
#[ignore]
fn typecheck_mutually_recursive_tag_union_listabc() { fn typecheck_mutually_recursive_tag_union_listabc() {
infer_eq_without_problem( infer_eq_without_problem(
indoc!( indoc!(
@ -5196,4 +5194,22 @@ mod solve_expr {
r#"{ bi128 : I128 -> I128, bi16 : I16 -> I16, bi32 : I32 -> I32, bi64 : I64 -> I64, bi8 : I8 -> I8, bnat : Nat -> Nat, bu128 : U128 -> U128, bu16 : U16 -> U16, bu32 : U32 -> U32, bu64 : U64 -> U64, bu8 : U8 -> U8, dec : Dec -> Dec, f32 : F32 -> F32, f64 : F64 -> F64, fdec : Dec -> Dec, ff32 : F32 -> F32, ff64 : F64 -> F64, i128 : I128 -> I128, i16 : I16 -> I16, i32 : I32 -> I32, i64 : I64 -> I64, i8 : I8 -> I8, nat : Nat -> Nat, u128 : U128 -> U128, u16 : U16 -> U16, u32 : U32 -> U32, u64 : U64 -> U64, u8 : U8 -> U8 }"#, r#"{ bi128 : I128 -> I128, bi16 : I16 -> I16, bi32 : I32 -> I32, bi64 : I64 -> I64, bi8 : I8 -> I8, bnat : Nat -> Nat, bu128 : U128 -> U128, bu16 : U16 -> U16, bu32 : U32 -> U32, bu64 : U64 -> U64, bu8 : U8 -> U8, dec : Dec -> Dec, f32 : F32 -> F32, f64 : F64 -> F64, fdec : Dec -> Dec, ff32 : F32 -> F32, ff64 : F64 -> F64, i128 : I128 -> I128, i16 : I16 -> I16, i32 : I32 -> I32, i64 : I64 -> I64, i8 : I8 -> I8, nat : Nat -> Nat, u128 : U128 -> U128, u16 : U16 -> U16, u32 : U32 -> U32, u64 : U64 -> U64, u8 : U8 -> U8 }"#,
) )
} }
#[test]
fn issue_2458() {
infer_eq_without_problem(
indoc!(
r#"
Foo a : [ Blah (Result (Bar a) { val: a }) ]
Bar a : Foo a
v : Bar U8
v = Blah (Ok (Blah (Err { val: 1 })))
v
"#
),
"Bar U8",
)
}
} }

View file

@ -1474,3 +1474,47 @@ fn issue_2445() {
i64 i64
); );
} }
#[test]
#[cfg(any(feature = "gen-llvm"))]
fn issue_2458() {
assert_evals_to!(
indoc!(
r#"
Foo a : [ Blah (Bar a), Nothing {} ]
Bar a : Foo a
v : Bar {}
v = Blah (Blah (Nothing {}))
when v is
Blah (Blah (Nothing {})) -> 15
_ -> 25
"#
),
15,
u8
)
}
#[test]
#[ignore = "See https://github.com/rtfeldman/roc/issues/2466"]
#[cfg(any(feature = "gen-llvm"))]
fn issue_2458_deep_recursion_var() {
assert_evals_to!(
indoc!(
r#"
Foo a : [ Blah (Result (Bar a) {}) ]
Bar a : Foo a
v : Bar {}
when v is
Blah (Ok (Blah (Err {}))) -> "1"
_ -> "2"
"#
),
15,
u8
)
}

View file

@ -3348,6 +3348,8 @@ mod test_reporting {
{ x, y } { x, y }
"# "#
), ),
// TODO render tag unions across multiple lines
// TODO do not show recursion var if the recursion var does not render on the surface of a type
indoc!( indoc!(
r#" r#"
TYPE MISMATCH TYPE MISMATCH
@ -3360,12 +3362,13 @@ mod test_reporting {
This `ACons` global tag application has the type: This `ACons` global tag application has the type:
[ ACons Num (Integer Signed64) [ BCons (Num a) [ ACons Str [ BNil [ ACons (Num (Integer Signed64)) [
]b ]c ]d, ANil ] BCons (Num (Integer Signed64)) [ ACons Str [ BCons I64 a, BNil ],
ANil ], BNil ], ANil ]
But the type annotation on `x` says it should be: But the type annotation on `x` says it should be:
[ ACons I64 BList I64 I64, ANil ] [ ACons I64 (BList I64 I64), ANil ] as a
"# "#
), ),
) )
@ -8068,4 +8071,20 @@ I need all branches in an `if` to have the same type!
), ),
) )
} }
#[test]
fn issue_2458() {
report_problem_as(
indoc!(
r#"
Foo a : [ Blah (Result (Bar a) []) ]
Bar a : Foo a
v : Bar U8
v
"#
),
"",
)
}
} }