Lift lambda sets as their own type

This prepares for unspecialized lambda set in the type system in
general.
This commit is contained in:
Ayaz Hafiz 2022-05-31 14:36:44 -05:00
parent 40b43ea98d
commit c2a2ce690c
No known key found for this signature in database
GPG key ID: 0E2A37416A25EF58
13 changed files with 708 additions and 410 deletions

View file

@ -8,8 +8,8 @@ use roc_module::symbol::Symbol;
use roc_types::num::NumericRange;
use roc_types::subs::Content::{self, *};
use roc_types::subs::{
AliasVariables, Descriptor, ErrorTypeContext, FlatType, GetSubsSlice, Mark, OptVariable,
RecordFields, Subs, SubsIndex, SubsSlice, UnionTags, Variable, VariableSubsSlice,
AliasVariables, Descriptor, ErrorTypeContext, FlatType, GetSubsSlice, LambdaSet, Mark,
OptVariable, RecordFields, Subs, SubsIndex, SubsSlice, UnionTags, Variable, VariableSubsSlice,
};
use roc_types::types::{AliasKind, DoesNotImplementAbility, ErrorType, Mismatch, RecordField};
@ -22,7 +22,7 @@ macro_rules! mismatch {
line!(),
column!()
);
})
});
Outcome {
mismatches: vec![Mismatch::TypeMismatch],
@ -395,6 +395,7 @@ fn unify_context(subs: &mut Subs, pool: &mut Pool, ctx: Context) -> Outcome {
Alias(symbol, args, real_var, AliasKind::Opaque) => {
unify_opaque(subs, pool, &ctx, *symbol, *args, *real_var)
}
LambdaSet(lset) => unify_lambda_set(subs, pool, &ctx, *lset, &ctx.second_desc.content),
&RangedNumber(typ, range_vars) => unify_ranged_number(subs, pool, &ctx, typ, range_vars),
Error => {
// Error propagates. Whatever we're comparing it to doesn't matter!
@ -438,6 +439,7 @@ fn unify_ranged_number(
}
// TODO: We should probably check that "range_vars" and "other_range_vars" intersect
}
LambdaSet(..) => mismatch!(),
Error => merge(subs, ctx, Error),
};
@ -582,6 +584,7 @@ fn unify_alias(
outcome
}
}
LambdaSet(..) => mismatch!("cannot unify alias {:?} with lambda set {:?}: lambda sets should never be directly behind an alias!", ctx.first, other_content),
Error => merge(subs, ctx, Error),
}
}
@ -766,6 +769,13 @@ fn unify_structure(
)
}
},
LambdaSet(..) => {
mismatch!(
"Cannot unify structure {:?} with lambda set {:?}",
&flat_type,
other
)
}
RangedNumber(other_real_var, other_range_vars) => {
let outcome = unify_pool(subs, pool, ctx.first, *other_real_var, ctx.mode);
if outcome.mismatches.is_empty() {
@ -778,6 +788,109 @@ fn unify_structure(
}
}
#[inline(always)]
fn unify_lambda_set(
subs: &mut Subs,
pool: &mut Pool,
ctx: &Context,
lambda_set: LambdaSet,
other: &Content,
) -> Outcome {
match other {
FlexVar(_) => merge(subs, ctx, Content::LambdaSet(lambda_set)),
Content::LambdaSet(other_lambda_set) => {
unify_lambda_set_help(subs, pool, ctx, lambda_set, *other_lambda_set)
}
RigidVar(..) | RigidAbleVar(..) => mismatch!("Lambda sets never unify with rigid"),
FlexAbleVar(..) => mismatch!("Lambda sets should never have abilities attached to them"),
Structure(..) => mismatch!("Lambda set cannot unify with non-lambda set structure"),
RangedNumber(..) => mismatch!("Lambda sets are never numbers"),
RecursionVar { .. } => mismatch!("Lambda set not expected to be recursive!"),
Alias(..) => mismatch!("Lambda set can never be directly under an alias!"),
Error => merge(subs, ctx, Error),
}
}
#[allow(clippy::too_many_arguments)]
fn unify_lambda_set_help(
subs: &mut Subs,
pool: &mut Pool,
ctx: &Context,
lset1: self::LambdaSet,
lset2: self::LambdaSet,
) -> Outcome {
// LambdaSets unify like TagUnions, but can grow unbounded regardless of the extension
// variable.
let LambdaSet { solved: solved1 } = lset1;
let LambdaSet { solved: solved2 } = lset2;
let (separate_solved, _, _) = separate_union_tags(
subs,
solved1,
Variable::EMPTY_TAG_UNION,
solved2,
Variable::EMPTY_TAG_UNION,
);
let Separate {
only_in_1,
only_in_2,
in_both,
} = separate_solved;
let num_shared = in_both.len();
let mut joined_lambdas = vec![];
for (tag_name, (vars1, vars2)) in in_both {
let mut joined_vars = vec![];
if vars1.len() != vars2.len() {
continue; // this is a type mismatch; not adding the tag will trigger it below.
}
let num_vars = vars1.len();
for (var1, var2) in (vars1.into_iter()).zip(vars2.into_iter()) {
let (var1, var2) = (subs[var1], subs[var2]);
let outcome = unify_pool(subs, pool, var1, var2, ctx.mode);
if outcome.mismatches.is_empty() {
// otherwise this is a type mismatch; not adding the variable will trigger it below.
joined_vars.push(var1);
}
}
if joined_vars.len() == num_vars {
joined_lambdas.push((tag_name, joined_vars));
}
}
if joined_lambdas.len() == num_shared {
mismatch!(
"Problem with lambda sets: there should be {:?} matching lambda, but only found {:?}",
num_shared,
&joined_lambdas
)
} else {
let all_lambdas = merge_sorted(
joined_lambdas,
(only_in_1.into_iter())
.chain(only_in_2.into_iter())
.map(|(name, subs_slice)| {
let vec = subs.get_subs_slice(subs_slice).to_vec();
(name, vec)
}),
);
let new_solved = UnionTags::insert_into_subs(subs, all_lambdas);
let new_lambda_set = Content::LambdaSet(LambdaSet { solved: new_solved });
merge(subs, ctx, new_lambda_set)
}
}
/// Ensures that a non-recursive tag union, when unified with a recursion var to become a recursive
/// tag union, properly contains a recursion variable that recurses on itself.
//
@ -1184,7 +1297,7 @@ enum Rec {
}
#[allow(clippy::too_many_arguments)]
fn unify_tag_union_new(
fn unify_tag_unions(
subs: &mut Subs,
pool: &mut Pool,
ctx: &Context,
@ -1573,7 +1686,7 @@ fn unify_flat_type(
}
(TagUnion(tags1, ext1), TagUnion(tags2, ext2)) => {
unify_tag_union_new(subs, pool, ctx, *tags1, *ext1, *tags2, *ext2, Rec::None)
unify_tag_unions(subs, pool, ctx, *tags1, *ext1, *tags2, *ext2, Rec::None)
}
(RecursiveTagUnion(recursion_var, tags1, ext1), TagUnion(tags2, ext2)) => {
@ -1582,7 +1695,7 @@ fn unify_flat_type(
let rec = Rec::Left(*recursion_var);
unify_tag_union_new(subs, pool, ctx, *tags1, *ext1, *tags2, *ext2, rec)
unify_tag_unions(subs, pool, ctx, *tags1, *ext1, *tags2, *ext2, rec)
}
(TagUnion(tags1, ext1), RecursiveTagUnion(recursion_var, tags2, ext2)) => {
@ -1590,7 +1703,7 @@ fn unify_flat_type(
let rec = Rec::Right(*recursion_var);
unify_tag_union_new(subs, pool, ctx, *tags1, *ext1, *tags2, *ext2, rec)
unify_tag_unions(subs, pool, ctx, *tags1, *ext1, *tags2, *ext2, rec)
}
(RecursiveTagUnion(rec1, tags1, ext1), RecursiveTagUnion(rec2, tags2, ext2)) => {
@ -1598,8 +1711,7 @@ fn unify_flat_type(
debug_assert!(is_recursion_var(subs, *rec2));
let rec = Rec::Both(*rec1, *rec2);
let mut outcome =
unify_tag_union_new(subs, pool, ctx, *tags1, *ext1, *tags2, *ext2, rec);
let mut outcome = unify_tag_unions(subs, pool, ctx, *tags1, *ext1, *tags2, *ext2, rec);
outcome.union(unify_pool(subs, pool, *rec1, *rec2, ctx.mode));
outcome
@ -1680,18 +1792,18 @@ fn unify_flat_type(
let tags1 = UnionTags::from_tag_name_index(*tag_name_1);
let tags2 = UnionTags::from_tag_name_index(*tag_name_2);
unify_tag_union_new(subs, pool, ctx, tags1, *ext1, tags2, *ext2, Rec::None)
unify_tag_unions(subs, pool, ctx, tags1, *ext1, tags2, *ext2, Rec::None)
}
}
(TagUnion(tags1, ext1), FunctionOrTagUnion(tag_name, _, ext2)) => {
let tags2 = UnionTags::from_tag_name_index(*tag_name);
unify_tag_union_new(subs, pool, ctx, *tags1, *ext1, tags2, *ext2, Rec::None)
unify_tag_unions(subs, pool, ctx, *tags1, *ext1, tags2, *ext2, Rec::None)
}
(FunctionOrTagUnion(tag_name, _, ext1), TagUnion(tags2, ext2)) => {
let tags1 = UnionTags::from_tag_name_index(*tag_name);
unify_tag_union_new(subs, pool, ctx, tags1, *ext1, *tags2, *ext2, Rec::None)
unify_tag_unions(subs, pool, ctx, tags1, *ext1, *tags2, *ext2, Rec::None)
}
(RecursiveTagUnion(recursion_var, tags1, ext1), FunctionOrTagUnion(tag_name, _, ext2)) => {
@ -1701,7 +1813,7 @@ fn unify_flat_type(
let tags2 = UnionTags::from_tag_name_index(*tag_name);
let rec = Rec::Left(*recursion_var);
unify_tag_union_new(subs, pool, ctx, *tags1, *ext1, tags2, *ext2, rec)
unify_tag_unions(subs, pool, ctx, *tags1, *ext1, tags2, *ext2, rec)
}
(FunctionOrTagUnion(tag_name, _, ext1), RecursiveTagUnion(recursion_var, tags2, ext2)) => {
@ -1710,7 +1822,7 @@ fn unify_flat_type(
let tags1 = UnionTags::from_tag_name_index(*tag_name);
let rec = Rec::Right(*recursion_var);
unify_tag_union_new(subs, pool, ctx, tags1, *ext1, *tags2, *ext2, rec)
unify_tag_unions(subs, pool, ctx, tags1, *ext1, *tags2, *ext2, rec)
}
// these have underscores because they're unused in --release builds
@ -1787,7 +1899,12 @@ fn unify_rigid(
}
}
RigidVar(_) | RecursionVar { .. } | Structure(_) | Alias(_, _, _, _) | RangedNumber(..)
RigidVar(_)
| RecursionVar { .. }
| Structure(_)
| Alias(_, _, _, _)
| RangedNumber(..)
| LambdaSet(..)
if ctx.mode.contains(Mode::RIGID_AS_FLEX) =>
{
// Usually rigids can only unify with flex, but the mode indicates we are treating
@ -1824,7 +1941,8 @@ fn unify_rigid(
| RecursionVar { .. }
| Structure(_)
| Alias(..)
| RangedNumber(..) => {
| RangedNumber(..)
| LambdaSet(..) => {
// Type mismatch! Rigid can only unify with flex, even if the
// rigid names are the same.
mismatch!("Rigid {:?} with {:?}", ctx.first, &other)
@ -1888,7 +2006,8 @@ fn unify_flex(
| RecursionVar { .. }
| Structure(_)
| Alias(_, _, _, _)
| RangedNumber(..) => {
| RangedNumber(..)
| LambdaSet(..) => {
// TODO special-case boolean here
// In all other cases, if left is flex, defer to right.
merge(subs, ctx, *other)
@ -1967,6 +2086,10 @@ fn unify_recursion(
&other
),
LambdaSet(..) => {
mismatch!("RecursionVar {:?} with LambdaSet {:?}", ctx.first, &other)
}
Error => merge(subs, ctx, Error),
}
}