support Tag arguments

This commit is contained in:
Folkert 2020-01-05 22:49:08 +01:00
parent 1d5ceeafe4
commit 5ad2823a23
8 changed files with 251 additions and 77 deletions

View file

@ -108,7 +108,7 @@ pub enum Expr {
variant_var: Variable,
ext_var: Variable,
name: Symbol,
arguments: Vec<Located<Expr>>,
arguments: Vec<(Variable, Located<Expr>)>,
},
// Compiles, but will crash if reached
@ -280,6 +280,17 @@ pub fn canonicalize_expr(
// We can't call a runtime error; bail out by propagating it!
return (fn_expr, output);
}
Tag {
variant_var,
ext_var,
name,
..
} => Tag {
variant_var,
ext_var,
name,
arguments: args,
},
_ => {
// This could be something like ((if True then fn1 else fn2) arg1 arg2).
Call(

View file

@ -462,18 +462,44 @@ pub fn constrain_expr(
])
}
Tag {
variant_var: _,
variant_var,
ext_var,
name,
arguments: _,
} => Eq(
Type::TagUnion(
vec![(name.clone(), vec![])],
Box::new(Type::Variable(*ext_var)),
),
expected,
region,
),
arguments,
} => {
let mut vars = Vec::with_capacity(arguments.len());
let mut types = Vec::with_capacity(arguments.len());
let mut arg_cons = Vec::with_capacity(arguments.len());
for (var, loc_expr) in arguments.into_iter() {
let arg_con = constrain_expr(
rigids,
loc_expr.region,
&loc_expr.value,
Expected::NoExpectation(Type::Variable(*var)),
);
arg_cons.push(arg_con);
vars.push(*var);
types.push(Type::Variable(*var));
}
let union_con = Eq(
Type::TagUnion(
vec![(name.clone(), types)],
Box::new(Type::Variable(*ext_var)),
),
expected.clone(),
region,
);
let ast_con = Eq(Type::Variable(*variant_var), expected, region);
vars.push(*variant_var);
arg_cons.push(union_con);
arg_cons.push(ast_con);
exists(vars, And(arg_cons))
}
RuntimeError(_) => True,
}
}

View file

@ -98,12 +98,8 @@ fn find_names_needed(
find_names_needed(ext_var, subs, roots, root_appearances, names_taken);
}
Structure(TagUnion(tags, ext_var)) => {
for arities in tags.values() {
for arity in arities.values() {
for var in arity {
find_names_needed(*var, subs, roots, root_appearances, names_taken);
}
}
for var in tags.values().flatten() {
find_names_needed(*var, subs, roots, root_appearances, names_taken);
}
find_names_needed(ext_var, subs, roots, root_appearances, names_taken);
@ -260,10 +256,8 @@ fn write_flat_type(flat_type: FlatType, subs: &mut Subs, buf: &mut String, paren
// Sort the fields so they always end up in the same order.
let mut sorted_fields = Vec::with_capacity(tags.len());
for (label, arities) in tags {
for (_, vars) in arities {
sorted_fields.push((label.clone(), vars));
}
for (label, vars) in tags {
sorted_fields.push((label.clone(), vars));
}
sorted_fields.sort_by(|(a, _), (b, _)| a.cmp(b));

View file

@ -438,18 +438,7 @@ fn type_to_variable(
tag_argument_vars.push(type_to_variable(subs, rank, pools, aliases, arg_type));
}
let f = |opt_value: Option<ImMap<usize, Vec<Variable>>>| -> Option<ImMap<usize, Vec<Variable>>> {
if let Some(mut current) = opt_value {
current.insert(tag_argument_vars.len(), tag_argument_vars);
Some(current)
} else {
let mut new = ImMap::default();
new.insert(tag_argument_vars.len(), tag_argument_vars);
Some(new)
}
};
tag_vars = tag_vars.alter(f, tag.clone())
tag_vars.insert(tag.clone(), tag_argument_vars);
}
let ext_var = type_to_variable(subs, rank, pools, aliases, ext);
@ -676,14 +665,9 @@ fn adjust_rank_content(
TagUnion(tags, ext_var) => {
let mut rank = adjust_rank(subs, young_mark, visit_mark, group_rank, ext_var);
for arities in tags.values() {
for arity in arities.values() {
for var in arity {
rank = rank.max(adjust_rank(
subs, young_mark, visit_mark, group_rank, *var,
));
}
}
for var in tags.values().flatten() {
rank =
rank.max(adjust_rank(subs, young_mark, visit_mark, group_rank, *var));
}
rank
@ -812,16 +796,12 @@ fn deep_copy_var_help(
TagUnion(tags, ext_var) => {
let mut new_tags = ImMap::default();
for (tag, arities) in tags {
let mut tag_at_arity = ImMap::default();
for (arity, vars) in arities {
let new_vars: Vec<Variable> = vars
.into_iter()
.map(|var| deep_copy_var_help(subs, max_rank, pools, var))
.collect();
tag_at_arity.insert(arity, new_vars);
}
new_tags.insert(tag, tag_at_arity);
for (tag, vars) in tags {
let new_vars: Vec<Variable> = vars
.into_iter()
.map(|var| deep_copy_var_help(subs, max_rank, pools, var))
.collect();
new_tags.insert(tag, new_vars);
}
TagUnion(new_tags, deep_copy_var_help(subs, max_rank, pools, ext_var))

View file

@ -433,7 +433,7 @@ pub enum FlatType {
Record(ImMap<RecordFieldLabel, Variable>, Variable),
// Within a tag union, a tag can occur multiple times, e.g. [ Foo, Foo Int, Foo Bool Int ], but
// only once for every arity, so not [ Foo Int, Foo Bool ]
TagUnion(ImMap<Symbol, ImMap<usize, Vec<Variable>>>, Variable),
TagUnion(ImMap<Symbol, Vec<Variable>>, Variable),
Erroneous(Problem),
EmptyRecord,
EmptyTagUnion,
@ -476,10 +476,8 @@ fn occurs(subs: &mut Subs, seen: &ImSet<Variable>, var: Variable) -> bool {
}
TagUnion(tags, ext_var) => {
occurs(subs, &new_seen, ext_var)
|| tags.values().any(|arities| {
arities.values().any(|vars| {
vars.into_iter().any(|var| occurs(subs, &new_seen, *var))
})
|| tags.values().any(|vars| {
vars.into_iter().any(|var| occurs(subs, &new_seen, *var))
})
}
EmptyRecord | EmptyTagUnion | Erroneous(_) => false,
@ -552,11 +550,9 @@ fn get_var_names(
FlatType::TagUnion(tags, ext_var) => {
let mut taken_names = get_var_names(subs, ext_var, taken_names);
for arities in tags.values() {
for arity in arities.values() {
for arg_var in arity {
taken_names = get_var_names(subs, *arg_var, taken_names)
}
for vars in tags.values() {
for arg_var in vars {
taken_names = get_var_names(subs, *arg_var, taken_names)
}
}
@ -703,7 +699,7 @@ fn flat_type_to_err_type(subs: &mut Subs, state: &mut NameState, flat_type: Flat
}
EmptyRecord => ErrorType::Record(SendMap::default(), TypeExt::Closed),
EmptyTagUnion => ErrorType::TagUnion(Vec::new(), TypeExt::Closed),
EmptyTagUnion => ErrorType::TagUnion(SendMap::default(), TypeExt::Closed),
Record(vars_by_field, ext_var) => {
let mut err_fields = SendMap::default();
@ -730,8 +726,35 @@ fn flat_type_to_err_type(subs: &mut Subs, state: &mut NameState, flat_type: Flat
}
}
TagUnion(_tags, _ext_var) => {
panic!("TODO implement error type for TagUnion");
TagUnion(tags, ext_var) => {
let mut err_tags = SendMap::default();
for (tag, vars) in tags.into_iter() {
let mut err_vars = Vec::with_capacity(vars.len());
for var in vars {
err_vars.push(var_to_err_type(subs, state, var));
}
err_tags.insert(tag, err_vars);
}
match var_to_err_type(subs, state, ext_var).unwrap_alias() {
ErrorType::TagUnion(sub_tags, sub_ext) => {
ErrorType::TagUnion(sub_tags.union(err_tags), sub_ext)
}
ErrorType::FlexVar(var) => {
ErrorType::TagUnion(err_tags, TypeExt::FlexOpen(var))
}
ErrorType::RigidVar(var) => {
ErrorType::TagUnion(err_tags, TypeExt::RigidOpen(var))
}
other =>
panic!("Tried to convert a tag union extension to an error, but the tag union extension had the ErrorType of {:?}", other)
}
}
Erroneous(_) => ErrorType::Error,
@ -779,12 +802,8 @@ fn restore_content(subs: &mut Subs, content: &Content) {
subs.restore(*ext_var);
}
TagUnion(tags, ext_var) => {
for arities in tags.values() {
for arity in arities.values() {
for var in arity {
subs.restore(*var);
}
}
for var in tags.values().flatten() {
subs.restore(*var);
}
subs.restore(*ext_var);

View file

@ -381,7 +381,7 @@ pub enum ErrorType {
FlexVar(Lowercase),
RigidVar(Lowercase),
Record(SendMap<RecordFieldLabel, ErrorType>, TypeExt),
TagUnion(Vec<(Uppercase, Vec<ErrorType>)>, TypeExt),
TagUnion(SendMap<Symbol, Vec<ErrorType>>, TypeExt),
Function(Vec<ErrorType>, Box<ErrorType>),
Alias(
ModuleName,

View file

@ -1,4 +1,5 @@
use crate::can::ident::{Lowercase, ModuleName, Uppercase};
use crate::can::symbol::Symbol;
use crate::collections::ImMap;
use crate::subs::Content::{self, *};
use crate::subs::{Descriptor, FlatType, Mark, OptVariable, Subs, Variable};
@ -19,6 +20,11 @@ struct RecordStructure {
ext: Variable,
}
struct TagUnionStructure {
tags: ImMap<Symbol, Vec<Variable>>,
ext: Variable,
}
pub struct Unified {
pub vars: Pool,
pub mismatches: Vec<Problem>,
@ -244,6 +250,115 @@ fn unify_shared_fields(
}
}
fn unify_tag_union(
subs: &mut Subs,
pool: &mut Pool,
ctx: &Context,
rec1: TagUnionStructure,
rec2: TagUnionStructure,
) -> Outcome {
let tags1 = rec1.tags;
let tags2 = rec2.tags;
let shared_tags = tags1
.clone()
.intersection_with(tags2.clone(), |one, two| (one, two));
// NOTE: don't use `difference` here, in contrast to Haskell, im-rc `difference` is symmetric
let unique_tags1 = tags1.clone().relative_complement(tags2.clone());
let unique_tags2 = tags2.relative_complement(tags1);
if unique_tags1.is_empty() {
if unique_tags2.is_empty() {
let ext_problems = unify_pool(subs, pool, rec1.ext, rec2.ext);
let mut tag_problems =
unify_shared_tags(subs, pool, ctx, shared_tags, ImMap::default(), rec1.ext);
tag_problems.extend(ext_problems);
tag_problems
} else {
let flat_type = FlatType::TagUnion(unique_tags2, rec2.ext);
let sub_record = fresh(subs, pool, ctx, Structure(flat_type));
let ext_problems = unify_pool(subs, pool, rec1.ext, sub_record);
let mut tag_problems =
unify_shared_tags(subs, pool, ctx, shared_tags, ImMap::default(), sub_record);
tag_problems.extend(ext_problems);
tag_problems
}
} else if unique_tags2.is_empty() {
let flat_type = FlatType::TagUnion(unique_tags1, rec1.ext);
let sub_record = fresh(subs, pool, ctx, Structure(flat_type));
let ext_problems = unify_pool(subs, pool, sub_record, rec2.ext);
let mut tag_problems =
unify_shared_tags(subs, pool, ctx, shared_tags, ImMap::default(), sub_record);
tag_problems.extend(ext_problems);
tag_problems
} else {
let other_tags = unique_tags1.clone().union(unique_tags2.clone());
let ext = fresh(subs, pool, ctx, Content::FlexVar(None));
let flat_type1 = FlatType::TagUnion(unique_tags1, rec1.ext);
let flat_type2 = FlatType::TagUnion(unique_tags2, rec2.ext);
let sub1 = fresh(subs, pool, ctx, Structure(flat_type1));
let sub2 = fresh(subs, pool, ctx, Structure(flat_type2));
let rec1_problems = unify_pool(subs, pool, rec1.ext, sub2);
let rec2_problems = unify_pool(subs, pool, sub1, rec2.ext);
let mut tag_problems = unify_shared_tags(subs, pool, ctx, shared_tags, other_tags, ext);
tag_problems.reserve(rec1_problems.len() + rec2_problems.len());
tag_problems.extend(rec1_problems);
tag_problems.extend(rec2_problems);
tag_problems
}
}
fn unify_shared_tags(
subs: &mut Subs,
pool: &mut Pool,
ctx: &Context,
shared_tags: ImMap<Symbol, (Vec<Variable>, Vec<Variable>)>,
other_tags: ImMap<Symbol, Vec<Variable>>,
ext: Variable,
) -> Outcome {
let mut matching_tags = ImMap::default();
let num_shared_tags = shared_tags.len();
for (name, (actual_vars, expected_vars)) in shared_tags {
let mut matching_vars = Vec::with_capacity(actual_vars.len());
let actual_len = actual_vars.len();
let expected_len = expected_vars.len();
for (actual, expected) in actual_vars.into_iter().zip(expected_vars.into_iter()) {
let problems = unify_pool(subs, pool, actual, expected);
if problems.is_empty() {
matching_vars.push(actual);
}
}
// only do this check after unification so the error message has more info
if actual_len == expected_len && actual_len == matching_tags.len() {
matching_tags.insert(name, matching_vars);
}
}
if num_shared_tags == matching_tags.len() {
let flat_type = FlatType::TagUnion(matching_tags.union(other_tags), ext);
merge(subs, ctx, Structure(flat_type))
} else {
mismatch()
}
}
#[inline(always)]
fn unify_flat_type(
subs: &mut Subs,
@ -283,11 +398,10 @@ fn unify_flat_type(
}
(TagUnion(tags1, ext1), TagUnion(tags2, ext2)) => {
// let rec1 = gather_fields(subs, fields1.clone(), *ext1);
// let rec2 = gather_fields(subs, fields2.clone(), *ext2);
let union1 = gather_tags(subs, tags1.clone(), *ext1);
let union2 = gather_tags(subs, tags2.clone(), *ext2);
// unify_record(subs, pool, ctx, rec1, rec2)
panic!("TODO");
unify_tag_union(subs, pool, ctx, union1, union2)
}
(
@ -413,6 +527,25 @@ fn gather_fields(
}
}
fn gather_tags(
subs: &mut Subs,
tags: ImMap<Symbol, Vec<Variable>>,
var: Variable,
) -> TagUnionStructure {
use crate::subs::FlatType::*;
match subs.get(var).content {
Structure(TagUnion(sub_tags, sub_ext)) => gather_tags(subs, tags.union(sub_tags), sub_ext),
Alias(_, _, _, var) => {
// TODO according to elm/compiler: "TODO may be dropping useful alias info here"
gather_tags(subs, tags, var)
}
_ => TagUnionStructure { tags, ext: var },
}
}
fn merge(subs: &mut Subs, ctx: &Context, content: Content) -> Outcome {
let rank = ctx.first_desc.rank.min(ctx.second_desc.rank);
let desc = Descriptor {

View file

@ -1028,7 +1028,7 @@ mod test_infer {
r#"\Foo -> 42
"#
),
"[ Foo ] -> Int",
"[ Foo ]* -> Int",
);
}
@ -1042,7 +1042,18 @@ mod test_infer {
False -> 0
"#
),
"[ True, False ] -> Int",
"[ False, True ]* -> Int",
);
}
#[test]
fn tag_application() {
infer_eq(
indoc!(
r#"Foo "happy" 2020
"#
),
"[ Foo Str Int ]*",
);
}
}