support Tag arguments

This commit is contained in:
Folkert 2020-01-05 22:49:08 +01:00
parent 1d5ceeafe4
commit 5ad2823a23
8 changed files with 251 additions and 77 deletions

View file

@ -108,7 +108,7 @@ pub enum Expr {
variant_var: Variable, variant_var: Variable,
ext_var: Variable, ext_var: Variable,
name: Symbol, name: Symbol,
arguments: Vec<Located<Expr>>, arguments: Vec<(Variable, Located<Expr>)>,
}, },
// Compiles, but will crash if reached // Compiles, but will crash if reached
@ -280,6 +280,17 @@ pub fn canonicalize_expr(
// We can't call a runtime error; bail out by propagating it! // We can't call a runtime error; bail out by propagating it!
return (fn_expr, output); return (fn_expr, output);
} }
Tag {
variant_var,
ext_var,
name,
..
} => Tag {
variant_var,
ext_var,
name,
arguments: args,
},
_ => { _ => {
// This could be something like ((if True then fn1 else fn2) arg1 arg2). // This could be something like ((if True then fn1 else fn2) arg1 arg2).
Call( Call(

View file

@ -462,18 +462,44 @@ pub fn constrain_expr(
]) ])
} }
Tag { Tag {
variant_var: _, variant_var,
ext_var, ext_var,
name, name,
arguments: _, arguments,
} => Eq( } => {
Type::TagUnion( let mut vars = Vec::with_capacity(arguments.len());
vec![(name.clone(), vec![])], let mut types = Vec::with_capacity(arguments.len());
Box::new(Type::Variable(*ext_var)), let mut arg_cons = Vec::with_capacity(arguments.len());
),
expected, for (var, loc_expr) in arguments.into_iter() {
region, let arg_con = constrain_expr(
), rigids,
loc_expr.region,
&loc_expr.value,
Expected::NoExpectation(Type::Variable(*var)),
);
arg_cons.push(arg_con);
vars.push(*var);
types.push(Type::Variable(*var));
}
let union_con = Eq(
Type::TagUnion(
vec![(name.clone(), types)],
Box::new(Type::Variable(*ext_var)),
),
expected.clone(),
region,
);
let ast_con = Eq(Type::Variable(*variant_var), expected, region);
vars.push(*variant_var);
arg_cons.push(union_con);
arg_cons.push(ast_con);
exists(vars, And(arg_cons))
}
RuntimeError(_) => True, RuntimeError(_) => True,
} }
} }

View file

@ -98,12 +98,8 @@ fn find_names_needed(
find_names_needed(ext_var, subs, roots, root_appearances, names_taken); find_names_needed(ext_var, subs, roots, root_appearances, names_taken);
} }
Structure(TagUnion(tags, ext_var)) => { Structure(TagUnion(tags, ext_var)) => {
for arities in tags.values() { for var in tags.values().flatten() {
for arity in arities.values() { find_names_needed(*var, subs, roots, root_appearances, names_taken);
for var in arity {
find_names_needed(*var, subs, roots, root_appearances, names_taken);
}
}
} }
find_names_needed(ext_var, subs, roots, root_appearances, names_taken); find_names_needed(ext_var, subs, roots, root_appearances, names_taken);
@ -260,10 +256,8 @@ fn write_flat_type(flat_type: FlatType, subs: &mut Subs, buf: &mut String, paren
// Sort the fields so they always end up in the same order. // Sort the fields so they always end up in the same order.
let mut sorted_fields = Vec::with_capacity(tags.len()); let mut sorted_fields = Vec::with_capacity(tags.len());
for (label, arities) in tags { for (label, vars) in tags {
for (_, vars) in arities { sorted_fields.push((label.clone(), vars));
sorted_fields.push((label.clone(), vars));
}
} }
sorted_fields.sort_by(|(a, _), (b, _)| a.cmp(b)); sorted_fields.sort_by(|(a, _), (b, _)| a.cmp(b));

View file

@ -438,18 +438,7 @@ fn type_to_variable(
tag_argument_vars.push(type_to_variable(subs, rank, pools, aliases, arg_type)); tag_argument_vars.push(type_to_variable(subs, rank, pools, aliases, arg_type));
} }
let f = |opt_value: Option<ImMap<usize, Vec<Variable>>>| -> Option<ImMap<usize, Vec<Variable>>> { tag_vars.insert(tag.clone(), tag_argument_vars);
if let Some(mut current) = opt_value {
current.insert(tag_argument_vars.len(), tag_argument_vars);
Some(current)
} else {
let mut new = ImMap::default();
new.insert(tag_argument_vars.len(), tag_argument_vars);
Some(new)
}
};
tag_vars = tag_vars.alter(f, tag.clone())
} }
let ext_var = type_to_variable(subs, rank, pools, aliases, ext); let ext_var = type_to_variable(subs, rank, pools, aliases, ext);
@ -676,14 +665,9 @@ fn adjust_rank_content(
TagUnion(tags, ext_var) => { TagUnion(tags, ext_var) => {
let mut rank = adjust_rank(subs, young_mark, visit_mark, group_rank, ext_var); let mut rank = adjust_rank(subs, young_mark, visit_mark, group_rank, ext_var);
for arities in tags.values() { for var in tags.values().flatten() {
for arity in arities.values() { rank =
for var in arity { rank.max(adjust_rank(subs, young_mark, visit_mark, group_rank, *var));
rank = rank.max(adjust_rank(
subs, young_mark, visit_mark, group_rank, *var,
));
}
}
} }
rank rank
@ -812,16 +796,12 @@ fn deep_copy_var_help(
TagUnion(tags, ext_var) => { TagUnion(tags, ext_var) => {
let mut new_tags = ImMap::default(); let mut new_tags = ImMap::default();
for (tag, arities) in tags { for (tag, vars) in tags {
let mut tag_at_arity = ImMap::default(); let new_vars: Vec<Variable> = vars
for (arity, vars) in arities { .into_iter()
let new_vars: Vec<Variable> = vars .map(|var| deep_copy_var_help(subs, max_rank, pools, var))
.into_iter() .collect();
.map(|var| deep_copy_var_help(subs, max_rank, pools, var)) new_tags.insert(tag, new_vars);
.collect();
tag_at_arity.insert(arity, new_vars);
}
new_tags.insert(tag, tag_at_arity);
} }
TagUnion(new_tags, deep_copy_var_help(subs, max_rank, pools, ext_var)) TagUnion(new_tags, deep_copy_var_help(subs, max_rank, pools, ext_var))

View file

@ -433,7 +433,7 @@ pub enum FlatType {
Record(ImMap<RecordFieldLabel, Variable>, Variable), Record(ImMap<RecordFieldLabel, Variable>, Variable),
// Within a tag union, a tag can occur multiple times, e.g. [ Foo, Foo Int, Foo Bool Int ], but // Within a tag union, a tag can occur multiple times, e.g. [ Foo, Foo Int, Foo Bool Int ], but
// only once for every arity, so not [ Foo Int, Foo Bool ] // only once for every arity, so not [ Foo Int, Foo Bool ]
TagUnion(ImMap<Symbol, ImMap<usize, Vec<Variable>>>, Variable), TagUnion(ImMap<Symbol, Vec<Variable>>, Variable),
Erroneous(Problem), Erroneous(Problem),
EmptyRecord, EmptyRecord,
EmptyTagUnion, EmptyTagUnion,
@ -476,10 +476,8 @@ fn occurs(subs: &mut Subs, seen: &ImSet<Variable>, var: Variable) -> bool {
} }
TagUnion(tags, ext_var) => { TagUnion(tags, ext_var) => {
occurs(subs, &new_seen, ext_var) occurs(subs, &new_seen, ext_var)
|| tags.values().any(|arities| { || tags.values().any(|vars| {
arities.values().any(|vars| { vars.into_iter().any(|var| occurs(subs, &new_seen, *var))
vars.into_iter().any(|var| occurs(subs, &new_seen, *var))
})
}) })
} }
EmptyRecord | EmptyTagUnion | Erroneous(_) => false, EmptyRecord | EmptyTagUnion | Erroneous(_) => false,
@ -552,11 +550,9 @@ fn get_var_names(
FlatType::TagUnion(tags, ext_var) => { FlatType::TagUnion(tags, ext_var) => {
let mut taken_names = get_var_names(subs, ext_var, taken_names); let mut taken_names = get_var_names(subs, ext_var, taken_names);
for arities in tags.values() { for vars in tags.values() {
for arity in arities.values() { for arg_var in vars {
for arg_var in arity { taken_names = get_var_names(subs, *arg_var, taken_names)
taken_names = get_var_names(subs, *arg_var, taken_names)
}
} }
} }
@ -703,7 +699,7 @@ fn flat_type_to_err_type(subs: &mut Subs, state: &mut NameState, flat_type: Flat
} }
EmptyRecord => ErrorType::Record(SendMap::default(), TypeExt::Closed), EmptyRecord => ErrorType::Record(SendMap::default(), TypeExt::Closed),
EmptyTagUnion => ErrorType::TagUnion(Vec::new(), TypeExt::Closed), EmptyTagUnion => ErrorType::TagUnion(SendMap::default(), TypeExt::Closed),
Record(vars_by_field, ext_var) => { Record(vars_by_field, ext_var) => {
let mut err_fields = SendMap::default(); let mut err_fields = SendMap::default();
@ -730,8 +726,35 @@ fn flat_type_to_err_type(subs: &mut Subs, state: &mut NameState, flat_type: Flat
} }
} }
TagUnion(_tags, _ext_var) => { TagUnion(tags, ext_var) => {
panic!("TODO implement error type for TagUnion"); let mut err_tags = SendMap::default();
for (tag, vars) in tags.into_iter() {
let mut err_vars = Vec::with_capacity(vars.len());
for var in vars {
err_vars.push(var_to_err_type(subs, state, var));
}
err_tags.insert(tag, err_vars);
}
match var_to_err_type(subs, state, ext_var).unwrap_alias() {
ErrorType::TagUnion(sub_tags, sub_ext) => {
ErrorType::TagUnion(sub_tags.union(err_tags), sub_ext)
}
ErrorType::FlexVar(var) => {
ErrorType::TagUnion(err_tags, TypeExt::FlexOpen(var))
}
ErrorType::RigidVar(var) => {
ErrorType::TagUnion(err_tags, TypeExt::RigidOpen(var))
}
other =>
panic!("Tried to convert a tag union extension to an error, but the tag union extension had the ErrorType of {:?}", other)
}
} }
Erroneous(_) => ErrorType::Error, Erroneous(_) => ErrorType::Error,
@ -779,12 +802,8 @@ fn restore_content(subs: &mut Subs, content: &Content) {
subs.restore(*ext_var); subs.restore(*ext_var);
} }
TagUnion(tags, ext_var) => { TagUnion(tags, ext_var) => {
for arities in tags.values() { for var in tags.values().flatten() {
for arity in arities.values() { subs.restore(*var);
for var in arity {
subs.restore(*var);
}
}
} }
subs.restore(*ext_var); subs.restore(*ext_var);

View file

@ -381,7 +381,7 @@ pub enum ErrorType {
FlexVar(Lowercase), FlexVar(Lowercase),
RigidVar(Lowercase), RigidVar(Lowercase),
Record(SendMap<RecordFieldLabel, ErrorType>, TypeExt), Record(SendMap<RecordFieldLabel, ErrorType>, TypeExt),
TagUnion(Vec<(Uppercase, Vec<ErrorType>)>, TypeExt), TagUnion(SendMap<Symbol, Vec<ErrorType>>, TypeExt),
Function(Vec<ErrorType>, Box<ErrorType>), Function(Vec<ErrorType>, Box<ErrorType>),
Alias( Alias(
ModuleName, ModuleName,

View file

@ -1,4 +1,5 @@
use crate::can::ident::{Lowercase, ModuleName, Uppercase}; use crate::can::ident::{Lowercase, ModuleName, Uppercase};
use crate::can::symbol::Symbol;
use crate::collections::ImMap; use crate::collections::ImMap;
use crate::subs::Content::{self, *}; use crate::subs::Content::{self, *};
use crate::subs::{Descriptor, FlatType, Mark, OptVariable, Subs, Variable}; use crate::subs::{Descriptor, FlatType, Mark, OptVariable, Subs, Variable};
@ -19,6 +20,11 @@ struct RecordStructure {
ext: Variable, ext: Variable,
} }
struct TagUnionStructure {
tags: ImMap<Symbol, Vec<Variable>>,
ext: Variable,
}
pub struct Unified { pub struct Unified {
pub vars: Pool, pub vars: Pool,
pub mismatches: Vec<Problem>, pub mismatches: Vec<Problem>,
@ -244,6 +250,115 @@ fn unify_shared_fields(
} }
} }
fn unify_tag_union(
subs: &mut Subs,
pool: &mut Pool,
ctx: &Context,
rec1: TagUnionStructure,
rec2: TagUnionStructure,
) -> Outcome {
let tags1 = rec1.tags;
let tags2 = rec2.tags;
let shared_tags = tags1
.clone()
.intersection_with(tags2.clone(), |one, two| (one, two));
// NOTE: don't use `difference` here, in contrast to Haskell, im-rc `difference` is symmetric
let unique_tags1 = tags1.clone().relative_complement(tags2.clone());
let unique_tags2 = tags2.relative_complement(tags1);
if unique_tags1.is_empty() {
if unique_tags2.is_empty() {
let ext_problems = unify_pool(subs, pool, rec1.ext, rec2.ext);
let mut tag_problems =
unify_shared_tags(subs, pool, ctx, shared_tags, ImMap::default(), rec1.ext);
tag_problems.extend(ext_problems);
tag_problems
} else {
let flat_type = FlatType::TagUnion(unique_tags2, rec2.ext);
let sub_record = fresh(subs, pool, ctx, Structure(flat_type));
let ext_problems = unify_pool(subs, pool, rec1.ext, sub_record);
let mut tag_problems =
unify_shared_tags(subs, pool, ctx, shared_tags, ImMap::default(), sub_record);
tag_problems.extend(ext_problems);
tag_problems
}
} else if unique_tags2.is_empty() {
let flat_type = FlatType::TagUnion(unique_tags1, rec1.ext);
let sub_record = fresh(subs, pool, ctx, Structure(flat_type));
let ext_problems = unify_pool(subs, pool, sub_record, rec2.ext);
let mut tag_problems =
unify_shared_tags(subs, pool, ctx, shared_tags, ImMap::default(), sub_record);
tag_problems.extend(ext_problems);
tag_problems
} else {
let other_tags = unique_tags1.clone().union(unique_tags2.clone());
let ext = fresh(subs, pool, ctx, Content::FlexVar(None));
let flat_type1 = FlatType::TagUnion(unique_tags1, rec1.ext);
let flat_type2 = FlatType::TagUnion(unique_tags2, rec2.ext);
let sub1 = fresh(subs, pool, ctx, Structure(flat_type1));
let sub2 = fresh(subs, pool, ctx, Structure(flat_type2));
let rec1_problems = unify_pool(subs, pool, rec1.ext, sub2);
let rec2_problems = unify_pool(subs, pool, sub1, rec2.ext);
let mut tag_problems = unify_shared_tags(subs, pool, ctx, shared_tags, other_tags, ext);
tag_problems.reserve(rec1_problems.len() + rec2_problems.len());
tag_problems.extend(rec1_problems);
tag_problems.extend(rec2_problems);
tag_problems
}
}
fn unify_shared_tags(
subs: &mut Subs,
pool: &mut Pool,
ctx: &Context,
shared_tags: ImMap<Symbol, (Vec<Variable>, Vec<Variable>)>,
other_tags: ImMap<Symbol, Vec<Variable>>,
ext: Variable,
) -> Outcome {
let mut matching_tags = ImMap::default();
let num_shared_tags = shared_tags.len();
for (name, (actual_vars, expected_vars)) in shared_tags {
let mut matching_vars = Vec::with_capacity(actual_vars.len());
let actual_len = actual_vars.len();
let expected_len = expected_vars.len();
for (actual, expected) in actual_vars.into_iter().zip(expected_vars.into_iter()) {
let problems = unify_pool(subs, pool, actual, expected);
if problems.is_empty() {
matching_vars.push(actual);
}
}
// only do this check after unification so the error message has more info
if actual_len == expected_len && actual_len == matching_tags.len() {
matching_tags.insert(name, matching_vars);
}
}
if num_shared_tags == matching_tags.len() {
let flat_type = FlatType::TagUnion(matching_tags.union(other_tags), ext);
merge(subs, ctx, Structure(flat_type))
} else {
mismatch()
}
}
#[inline(always)] #[inline(always)]
fn unify_flat_type( fn unify_flat_type(
subs: &mut Subs, subs: &mut Subs,
@ -283,11 +398,10 @@ fn unify_flat_type(
} }
(TagUnion(tags1, ext1), TagUnion(tags2, ext2)) => { (TagUnion(tags1, ext1), TagUnion(tags2, ext2)) => {
// let rec1 = gather_fields(subs, fields1.clone(), *ext1); let union1 = gather_tags(subs, tags1.clone(), *ext1);
// let rec2 = gather_fields(subs, fields2.clone(), *ext2); let union2 = gather_tags(subs, tags2.clone(), *ext2);
// unify_record(subs, pool, ctx, rec1, rec2) unify_tag_union(subs, pool, ctx, union1, union2)
panic!("TODO");
} }
( (
@ -413,6 +527,25 @@ fn gather_fields(
} }
} }
fn gather_tags(
subs: &mut Subs,
tags: ImMap<Symbol, Vec<Variable>>,
var: Variable,
) -> TagUnionStructure {
use crate::subs::FlatType::*;
match subs.get(var).content {
Structure(TagUnion(sub_tags, sub_ext)) => gather_tags(subs, tags.union(sub_tags), sub_ext),
Alias(_, _, _, var) => {
// TODO according to elm/compiler: "TODO may be dropping useful alias info here"
gather_tags(subs, tags, var)
}
_ => TagUnionStructure { tags, ext: var },
}
}
fn merge(subs: &mut Subs, ctx: &Context, content: Content) -> Outcome { fn merge(subs: &mut Subs, ctx: &Context, content: Content) -> Outcome {
let rank = ctx.first_desc.rank.min(ctx.second_desc.rank); let rank = ctx.first_desc.rank.min(ctx.second_desc.rank);
let desc = Descriptor { let desc = Descriptor {

View file

@ -1028,7 +1028,7 @@ mod test_infer {
r#"\Foo -> 42 r#"\Foo -> 42
"# "#
), ),
"[ Foo ] -> Int", "[ Foo ]* -> Int",
); );
} }
@ -1042,7 +1042,18 @@ mod test_infer {
False -> 0 False -> 0
"# "#
), ),
"[ True, False ] -> Int", "[ False, True ]* -> Int",
);
}
#[test]
fn tag_application() {
infer_eq(
indoc!(
r#"Foo "happy" 2020
"#
),
"[ Foo Str Int ]*",
); );
} }
} }