mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-27 22:09:09 +00:00
Use fixpoint-fixing in unification
This commit is contained in:
parent
9a7402f40b
commit
5a92947326
2 changed files with 170 additions and 45 deletions
|
@ -1,5 +1,5 @@
|
|||
use bitflags::bitflags;
|
||||
use roc_collections::VecMap;
|
||||
use roc_collections::{VecMap, VecSet};
|
||||
use roc_debug_flags::dbg_do;
|
||||
#[cfg(debug_assertions)]
|
||||
use roc_debug_flags::{ROC_PRINT_MISMATCHES, ROC_PRINT_UNIFICATIONS};
|
||||
|
@ -321,6 +321,8 @@ impl<M: MetaCollector> Outcome<M> {
|
|||
pub struct Env<'a> {
|
||||
pub subs: &'a mut Subs,
|
||||
compute_outcome_only: bool,
|
||||
seen_recursion: VecSet<(Variable, Variable)>,
|
||||
fixed_variables: VecSet<Variable>,
|
||||
}
|
||||
|
||||
impl<'a> Env<'a> {
|
||||
|
@ -328,6 +330,8 @@ impl<'a> Env<'a> {
|
|||
Self {
|
||||
subs,
|
||||
compute_outcome_only: false,
|
||||
seen_recursion: Default::default(),
|
||||
fixed_variables: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -339,6 +343,48 @@ impl<'a> Env<'a> {
|
|||
self.compute_outcome_only = false;
|
||||
result
|
||||
}
|
||||
|
||||
fn add_recursion_pair(&mut self, var1: Variable, var2: Variable) {
|
||||
let pair = (
|
||||
self.subs.get_root_key_without_compacting(var1),
|
||||
self.subs.get_root_key_without_compacting(var2),
|
||||
);
|
||||
|
||||
let already_seen = self.seen_recursion.insert(pair);
|
||||
debug_assert!(!already_seen);
|
||||
}
|
||||
|
||||
fn remove_recursion_pair(&mut self, var1: Variable, var2: Variable) {
|
||||
#[cfg(debug_assertions)]
|
||||
let size_before = self.seen_recursion.len();
|
||||
|
||||
self.seen_recursion.retain(|(v1, v2)| {
|
||||
let is_recursion_pair = self.subs.equivalent_without_compacting(*v1, var1)
|
||||
&& self.subs.equivalent_without_compacting(*v2, var2);
|
||||
!is_recursion_pair
|
||||
});
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
let size_after = self.seen_recursion.len();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
debug_assert!(size_after < size_before, "nothing was removed");
|
||||
}
|
||||
|
||||
fn seen_recursion_pair(&mut self, var1: Variable, var2: Variable) -> bool {
|
||||
let (var1, var2) = (
|
||||
self.subs.get_root_key_without_compacting(var1),
|
||||
self.subs.get_root_key_without_compacting(var2),
|
||||
);
|
||||
|
||||
self.seen_recursion.contains(&(var1, var2))
|
||||
}
|
||||
|
||||
fn was_fixed(&self, var: Variable) -> bool {
|
||||
self.fixed_variables
|
||||
.iter()
|
||||
.any(|fixed_var| self.subs.equivalent_without_compacting(*fixed_var, var))
|
||||
}
|
||||
}
|
||||
|
||||
/// Unifies two types.
|
||||
|
@ -863,6 +909,12 @@ fn unify_two_aliases<M: MetaCollector>(
|
|||
}
|
||||
}
|
||||
|
||||
fn fix_fixpoint<M: MetaCollector>(env: &mut Env, ctx: &Context) -> Outcome<M> {
|
||||
let fixed_variables = crate::fix::fix_fixpoint(env.subs, ctx.first, ctx.second);
|
||||
env.fixed_variables.extend(fixed_variables);
|
||||
Default::default()
|
||||
}
|
||||
|
||||
// Unifies a structural alias
|
||||
#[inline(always)]
|
||||
#[must_use]
|
||||
|
@ -883,7 +935,17 @@ fn unify_alias<M: MetaCollector>(
|
|||
// Alias wins
|
||||
merge(env, ctx, Alias(symbol, args, real_var, kind))
|
||||
}
|
||||
RecursionVar { structure, .. } => unify_pool(env, pool, real_var, *structure, ctx.mode),
|
||||
RecursionVar { structure, .. } => {
|
||||
if env.seen_recursion_pair(ctx.first, ctx.second) {
|
||||
return fix_fixpoint(env, ctx);
|
||||
}
|
||||
|
||||
env.add_recursion_pair(ctx.first, ctx.second);
|
||||
let outcome = unify_pool(env, pool, real_var, *structure, ctx.mode);
|
||||
env.remove_recursion_pair(ctx.first, ctx.second);
|
||||
|
||||
outcome
|
||||
}
|
||||
RigidVar(_) | RigidAbleVar(..) | FlexAbleVar(..) => {
|
||||
unify_pool(env, pool, real_var, ctx.second, ctx.mode)
|
||||
}
|
||||
|
@ -956,7 +1018,17 @@ fn unify_opaque<M: MetaCollector>(
|
|||
Alias(_, _, other_real_var, AliasKind::Structural) => {
|
||||
unify_pool(env, pool, ctx.first, *other_real_var, ctx.mode)
|
||||
}
|
||||
RecursionVar { structure, .. } => unify_pool(env, pool, ctx.first, *structure, ctx.mode),
|
||||
RecursionVar { structure, .. } => {
|
||||
if env.seen_recursion_pair(ctx.first, ctx.second) {
|
||||
return fix_fixpoint(env, ctx);
|
||||
}
|
||||
|
||||
env.add_recursion_pair(ctx.first, ctx.second);
|
||||
let outcome = unify_pool(env, pool, real_var, *structure, ctx.mode);
|
||||
env.remove_recursion_pair(ctx.first, ctx.second);
|
||||
|
||||
outcome
|
||||
}
|
||||
Alias(other_symbol, other_args, other_real_var, AliasKind::Opaque) => {
|
||||
// Opaques types are only equal if the opaque symbols are equal!
|
||||
if symbol == *other_symbol {
|
||||
|
@ -1030,27 +1102,45 @@ fn unify_structure<M: MetaCollector>(
|
|||
&other
|
||||
)
|
||||
}
|
||||
RecursionVar { structure, .. } => match flat_type {
|
||||
FlatType::TagUnion(_, _) => {
|
||||
// unify the structure with this unrecursive tag union
|
||||
unify_pool(env, pool, ctx.first, *structure, ctx.mode)
|
||||
RecursionVar { structure, .. } => {
|
||||
if env.seen_recursion_pair(ctx.first, ctx.second) {
|
||||
return fix_fixpoint(env, ctx);
|
||||
}
|
||||
FlatType::RecursiveTagUnion(rec, _, _) => {
|
||||
debug_assert!(is_recursion_var(env.subs, *rec));
|
||||
// unify the structure with this recursive tag union
|
||||
unify_pool(env, pool, ctx.first, *structure, ctx.mode)
|
||||
}
|
||||
FlatType::FunctionOrTagUnion(_, _, _) => {
|
||||
// unify the structure with this unrecursive tag union
|
||||
unify_pool(env, pool, ctx.first, *structure, ctx.mode)
|
||||
}
|
||||
// Only tag unions can be recursive; everything else is an error.
|
||||
_ => mismatch!(
|
||||
"trying to unify {:?} with recursive type var {:?}",
|
||||
&flat_type,
|
||||
structure
|
||||
),
|
||||
},
|
||||
|
||||
env.add_recursion_pair(ctx.first, ctx.second);
|
||||
|
||||
let outcome = match flat_type {
|
||||
FlatType::TagUnion(_, _) => {
|
||||
// unify the structure with this unrecursive tag union
|
||||
unify_pool(env, pool, ctx.first, *structure, ctx.mode)
|
||||
}
|
||||
FlatType::RecursiveTagUnion(rec, _, _) => {
|
||||
debug_assert!(
|
||||
is_recursion_var(env.subs, *rec),
|
||||
"{:?}",
|
||||
roc_types::subs::SubsFmtContent(
|
||||
env.subs.get_content_without_compacting(*rec),
|
||||
env.subs
|
||||
)
|
||||
);
|
||||
// unify the structure with this recursive tag union
|
||||
unify_pool(env, pool, ctx.first, *structure, ctx.mode)
|
||||
}
|
||||
FlatType::FunctionOrTagUnion(_, _, _) => {
|
||||
// unify the structure with this unrecursive tag union
|
||||
unify_pool(env, pool, ctx.first, *structure, ctx.mode)
|
||||
}
|
||||
// Only tag unions can be recursive; everything else is an error.
|
||||
_ => mismatch!(
|
||||
"trying to unify {:?} with recursive type var {:?}",
|
||||
&flat_type,
|
||||
structure
|
||||
),
|
||||
};
|
||||
|
||||
env.remove_recursion_pair(ctx.first, ctx.second);
|
||||
outcome
|
||||
}
|
||||
|
||||
Structure(ref other_flat_type) => {
|
||||
// Unify the two flat types
|
||||
|
@ -1121,8 +1211,16 @@ fn unify_lambda_set<M: MetaCollector>(
|
|||
}
|
||||
}
|
||||
RecursionVar { structure, .. } => {
|
||||
if env.seen_recursion_pair(ctx.first, ctx.second) {
|
||||
return fix_fixpoint(env, ctx);
|
||||
}
|
||||
env.add_recursion_pair(ctx.first, ctx.second);
|
||||
|
||||
// suppose that the recursion var is a lambda set
|
||||
unify_pool(env, pool, ctx.first, *structure, ctx.mode)
|
||||
let outcome = unify_pool(env, pool, ctx.first, *structure, ctx.mode);
|
||||
|
||||
env.remove_recursion_pair(ctx.first, ctx.second);
|
||||
outcome
|
||||
}
|
||||
RigidVar(..) | RigidAbleVar(..) => mismatch!("Lambda sets never unify with rigid"),
|
||||
FlexAbleVar(..) => mismatch!("Lambda sets should never have abilities attached to them"),
|
||||
|
@ -2701,6 +2799,11 @@ fn unify_shared_tags_merge_new<M: MetaCollector>(
|
|||
new_ext_var: Variable,
|
||||
recursion_var: Rec,
|
||||
) -> Outcome<M> {
|
||||
let was_fixed = env.was_fixed(ctx.first) || env.was_fixed(ctx.second);
|
||||
if was_fixed {
|
||||
return Default::default();
|
||||
}
|
||||
|
||||
let flat_type = match recursion_var {
|
||||
Rec::None => FlatType::TagUnion(new_tags, new_ext_var),
|
||||
Rec::Left(rec) | Rec::Right(rec) | Rec::Both(rec, _) => {
|
||||
|
@ -2770,8 +2873,16 @@ fn unify_flat_type<M: MetaCollector>(
|
|||
}
|
||||
|
||||
(RecursiveTagUnion(rec1, tags1, ext1), RecursiveTagUnion(rec2, tags2, ext2)) => {
|
||||
debug_assert!(is_recursion_var(env.subs, *rec1));
|
||||
debug_assert!(is_recursion_var(env.subs, *rec2));
|
||||
debug_assert!(
|
||||
is_recursion_var(env.subs, *rec1),
|
||||
"{:?}",
|
||||
env.subs.dbg(*rec1)
|
||||
);
|
||||
debug_assert!(
|
||||
is_recursion_var(env.subs, *rec2),
|
||||
"{:?}",
|
||||
env.subs.dbg(*rec2)
|
||||
);
|
||||
|
||||
let rec = Rec::Both(*rec1, *rec2);
|
||||
let mut outcome = unify_tag_unions(env, pool, ctx, *tags1, *ext1, *tags2, *ext2, rec);
|
||||
|
@ -3240,7 +3351,15 @@ fn unify_recursion<M: MetaCollector>(
|
|||
structure: Variable,
|
||||
other: &Content,
|
||||
) -> Outcome<M> {
|
||||
match other {
|
||||
if !matches!(other, RecursionVar { .. }) {
|
||||
if env.seen_recursion_pair(ctx.first, ctx.second) {
|
||||
return Default::default();
|
||||
}
|
||||
|
||||
env.add_recursion_pair(ctx.first, ctx.second);
|
||||
}
|
||||
|
||||
let outcome = match other {
|
||||
RecursionVar {
|
||||
opt_name: other_opt_name,
|
||||
structure: _other_structure,
|
||||
|
@ -3315,7 +3434,13 @@ fn unify_recursion<M: MetaCollector>(
|
|||
}
|
||||
|
||||
Error => merge(env, ctx, Error),
|
||||
};
|
||||
|
||||
if !matches!(other, RecursionVar { .. }) {
|
||||
env.remove_recursion_pair(ctx.first, ctx.second);
|
||||
}
|
||||
|
||||
outcome
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue