Use fixpoint-fixing in unification

This commit is contained in:
Ayaz Hafiz 2022-11-15 09:54:32 -06:00
parent 9a7402f40b
commit 5a92947326
No known key found for this signature in database
GPG key ID: 0E2A37416A25EF58
2 changed files with 170 additions and 45 deletions

View file

@ -1,5 +1,5 @@
use bitflags::bitflags;
use roc_collections::VecMap;
use roc_collections::{VecMap, VecSet};
use roc_debug_flags::dbg_do;
#[cfg(debug_assertions)]
use roc_debug_flags::{ROC_PRINT_MISMATCHES, ROC_PRINT_UNIFICATIONS};
@ -321,6 +321,8 @@ impl<M: MetaCollector> Outcome<M> {
pub struct Env<'a> {
pub subs: &'a mut Subs,
compute_outcome_only: bool,
seen_recursion: VecSet<(Variable, Variable)>,
fixed_variables: VecSet<Variable>,
}
impl<'a> Env<'a> {
@ -328,6 +330,8 @@ impl<'a> Env<'a> {
Self {
subs,
compute_outcome_only: false,
seen_recursion: Default::default(),
fixed_variables: Default::default(),
}
}
@ -339,6 +343,48 @@ impl<'a> Env<'a> {
self.compute_outcome_only = false;
result
}
fn add_recursion_pair(&mut self, var1: Variable, var2: Variable) {
let pair = (
self.subs.get_root_key_without_compacting(var1),
self.subs.get_root_key_without_compacting(var2),
);
let already_seen = self.seen_recursion.insert(pair);
debug_assert!(!already_seen);
}
fn remove_recursion_pair(&mut self, var1: Variable, var2: Variable) {
#[cfg(debug_assertions)]
let size_before = self.seen_recursion.len();
self.seen_recursion.retain(|(v1, v2)| {
let is_recursion_pair = self.subs.equivalent_without_compacting(*v1, var1)
&& self.subs.equivalent_without_compacting(*v2, var2);
!is_recursion_pair
});
#[cfg(debug_assertions)]
let size_after = self.seen_recursion.len();
#[cfg(debug_assertions)]
debug_assert!(size_after < size_before, "nothing was removed");
}
fn seen_recursion_pair(&mut self, var1: Variable, var2: Variable) -> bool {
let (var1, var2) = (
self.subs.get_root_key_without_compacting(var1),
self.subs.get_root_key_without_compacting(var2),
);
self.seen_recursion.contains(&(var1, var2))
}
fn was_fixed(&self, var: Variable) -> bool {
self.fixed_variables
.iter()
.any(|fixed_var| self.subs.equivalent_without_compacting(*fixed_var, var))
}
}
/// Unifies two types.
@ -863,6 +909,12 @@ fn unify_two_aliases<M: MetaCollector>(
}
}
fn fix_fixpoint<M: MetaCollector>(env: &mut Env, ctx: &Context) -> Outcome<M> {
let fixed_variables = crate::fix::fix_fixpoint(env.subs, ctx.first, ctx.second);
env.fixed_variables.extend(fixed_variables);
Default::default()
}
// Unifies a structural alias
#[inline(always)]
#[must_use]
@ -883,7 +935,17 @@ fn unify_alias<M: MetaCollector>(
// Alias wins
merge(env, ctx, Alias(symbol, args, real_var, kind))
}
RecursionVar { structure, .. } => unify_pool(env, pool, real_var, *structure, ctx.mode),
RecursionVar { structure, .. } => {
if env.seen_recursion_pair(ctx.first, ctx.second) {
return fix_fixpoint(env, ctx);
}
env.add_recursion_pair(ctx.first, ctx.second);
let outcome = unify_pool(env, pool, real_var, *structure, ctx.mode);
env.remove_recursion_pair(ctx.first, ctx.second);
outcome
}
RigidVar(_) | RigidAbleVar(..) | FlexAbleVar(..) => {
unify_pool(env, pool, real_var, ctx.second, ctx.mode)
}
@ -956,7 +1018,17 @@ fn unify_opaque<M: MetaCollector>(
Alias(_, _, other_real_var, AliasKind::Structural) => {
unify_pool(env, pool, ctx.first, *other_real_var, ctx.mode)
}
RecursionVar { structure, .. } => unify_pool(env, pool, ctx.first, *structure, ctx.mode),
RecursionVar { structure, .. } => {
if env.seen_recursion_pair(ctx.first, ctx.second) {
return fix_fixpoint(env, ctx);
}
env.add_recursion_pair(ctx.first, ctx.second);
let outcome = unify_pool(env, pool, real_var, *structure, ctx.mode);
env.remove_recursion_pair(ctx.first, ctx.second);
outcome
}
Alias(other_symbol, other_args, other_real_var, AliasKind::Opaque) => {
// Opaques types are only equal if the opaque symbols are equal!
if symbol == *other_symbol {
@ -1030,27 +1102,45 @@ fn unify_structure<M: MetaCollector>(
&other
)
}
RecursionVar { structure, .. } => match flat_type {
FlatType::TagUnion(_, _) => {
// unify the structure with this unrecursive tag union
unify_pool(env, pool, ctx.first, *structure, ctx.mode)
RecursionVar { structure, .. } => {
if env.seen_recursion_pair(ctx.first, ctx.second) {
return fix_fixpoint(env, ctx);
}
FlatType::RecursiveTagUnion(rec, _, _) => {
debug_assert!(is_recursion_var(env.subs, *rec));
// unify the structure with this recursive tag union
unify_pool(env, pool, ctx.first, *structure, ctx.mode)
}
FlatType::FunctionOrTagUnion(_, _, _) => {
// unify the structure with this unrecursive tag union
unify_pool(env, pool, ctx.first, *structure, ctx.mode)
}
// Only tag unions can be recursive; everything else is an error.
_ => mismatch!(
"trying to unify {:?} with recursive type var {:?}",
&flat_type,
structure
),
},
env.add_recursion_pair(ctx.first, ctx.second);
let outcome = match flat_type {
FlatType::TagUnion(_, _) => {
// unify the structure with this unrecursive tag union
unify_pool(env, pool, ctx.first, *structure, ctx.mode)
}
FlatType::RecursiveTagUnion(rec, _, _) => {
debug_assert!(
is_recursion_var(env.subs, *rec),
"{:?}",
roc_types::subs::SubsFmtContent(
env.subs.get_content_without_compacting(*rec),
env.subs
)
);
// unify the structure with this recursive tag union
unify_pool(env, pool, ctx.first, *structure, ctx.mode)
}
FlatType::FunctionOrTagUnion(_, _, _) => {
// unify the structure with this unrecursive tag union
unify_pool(env, pool, ctx.first, *structure, ctx.mode)
}
// Only tag unions can be recursive; everything else is an error.
_ => mismatch!(
"trying to unify {:?} with recursive type var {:?}",
&flat_type,
structure
),
};
env.remove_recursion_pair(ctx.first, ctx.second);
outcome
}
Structure(ref other_flat_type) => {
// Unify the two flat types
@ -1121,8 +1211,16 @@ fn unify_lambda_set<M: MetaCollector>(
}
}
RecursionVar { structure, .. } => {
if env.seen_recursion_pair(ctx.first, ctx.second) {
return fix_fixpoint(env, ctx);
}
env.add_recursion_pair(ctx.first, ctx.second);
// suppose that the recursion var is a lambda set
unify_pool(env, pool, ctx.first, *structure, ctx.mode)
let outcome = unify_pool(env, pool, ctx.first, *structure, ctx.mode);
env.remove_recursion_pair(ctx.first, ctx.second);
outcome
}
RigidVar(..) | RigidAbleVar(..) => mismatch!("Lambda sets never unify with rigid"),
FlexAbleVar(..) => mismatch!("Lambda sets should never have abilities attached to them"),
@ -2701,6 +2799,11 @@ fn unify_shared_tags_merge_new<M: MetaCollector>(
new_ext_var: Variable,
recursion_var: Rec,
) -> Outcome<M> {
let was_fixed = env.was_fixed(ctx.first) || env.was_fixed(ctx.second);
if was_fixed {
return Default::default();
}
let flat_type = match recursion_var {
Rec::None => FlatType::TagUnion(new_tags, new_ext_var),
Rec::Left(rec) | Rec::Right(rec) | Rec::Both(rec, _) => {
@ -2770,8 +2873,16 @@ fn unify_flat_type<M: MetaCollector>(
}
(RecursiveTagUnion(rec1, tags1, ext1), RecursiveTagUnion(rec2, tags2, ext2)) => {
debug_assert!(is_recursion_var(env.subs, *rec1));
debug_assert!(is_recursion_var(env.subs, *rec2));
debug_assert!(
is_recursion_var(env.subs, *rec1),
"{:?}",
env.subs.dbg(*rec1)
);
debug_assert!(
is_recursion_var(env.subs, *rec2),
"{:?}",
env.subs.dbg(*rec2)
);
let rec = Rec::Both(*rec1, *rec2);
let mut outcome = unify_tag_unions(env, pool, ctx, *tags1, *ext1, *tags2, *ext2, rec);
@ -3240,7 +3351,15 @@ fn unify_recursion<M: MetaCollector>(
structure: Variable,
other: &Content,
) -> Outcome<M> {
match other {
if !matches!(other, RecursionVar { .. }) {
if env.seen_recursion_pair(ctx.first, ctx.second) {
return Default::default();
}
env.add_recursion_pair(ctx.first, ctx.second);
}
let outcome = match other {
RecursionVar {
opt_name: other_opt_name,
structure: _other_structure,
@ -3315,7 +3434,13 @@ fn unify_recursion<M: MetaCollector>(
}
Error => merge(env, ctx, Error),
};
if !matches!(other, RecursionVar { .. }) {
env.remove_recursion_pair(ctx.first, ctx.second);
}
outcome
}
#[must_use]