Unify FunctionOrTagUnion with regular tags and functions

This commit is contained in:
tarjei 2021-05-23 23:14:17 +02:00
parent a53c7b5546
commit 0ee15f15ee
4 changed files with 197 additions and 11 deletions

View file

@ -634,11 +634,14 @@ pub fn canonicalize_expr<'a>(
let variant_var = var_store.fresh(); let variant_var = var_store.fresh();
let ext_var = var_store.fresh(); let ext_var = var_store.fresh();
let symbol = env.gen_unique_symbol();
( (
Tag { ZeroArgumentTag {
name: TagName::Global((*tag).into()), name: TagName::Global((*tag).into()),
arguments: vec![], arguments: vec![],
variant_var, variant_var,
closure_name: symbol,
ext_var, ext_var,
}, },
Output::default(), Output::default(),

View file

@ -816,13 +816,6 @@ pub fn constrain_expr(
ext_var, ext_var,
name, name,
arguments, arguments,
}
| ZeroArgumentTag {
variant_var,
ext_var,
name,
arguments,
..
} => { } => {
let mut vars = Vec::with_capacity(arguments.len()); let mut vars = Vec::with_capacity(arguments.len());
let mut types = Vec::with_capacity(arguments.len()); let mut types = Vec::with_capacity(arguments.len());
@ -867,6 +860,58 @@ pub fn constrain_expr(
exists(vars, And(arg_cons)) exists(vars, And(arg_cons))
} }
ZeroArgumentTag {
variant_var,
ext_var,
name,
arguments,
closure_name,
} => {
let mut vars = Vec::with_capacity(arguments.len());
let mut types = Vec::with_capacity(arguments.len());
let mut arg_cons = Vec::with_capacity(arguments.len());
for (var, loc_expr) in arguments {
let arg_con = constrain_expr(
env,
loc_expr.region,
&loc_expr.value,
Expected::NoExpectation(Type::Variable(*var)),
);
arg_cons.push(arg_con);
vars.push(*var);
types.push(Type::Variable(*var));
}
let union_con = Eq(
Type::FunctionOrTagUnion(
name.clone(),
*closure_name,
Box::new(Type::Variable(*ext_var)),
),
expected.clone(),
Category::TagApply {
tag_name: name.clone(),
args_count: arguments.len(),
},
region,
);
let ast_con = Eq(
Type::Variable(*variant_var),
expected,
Category::Storage(std::file!(), std::line!()),
region,
);
vars.push(*variant_var);
vars.push(*ext_var);
arg_cons.push(union_con);
arg_cons.push(ast_con);
exists(vars, And(arg_cons))
}
RunLowLevel { args, ret_var, op } => { RunLowLevel { args, ret_var, op } => {
// This is a modified version of what we do for function calls. // This is a modified version of what we do for function calls.

View file

@ -490,7 +490,7 @@ fn write_flat_type(env: &Env, flat_type: FlatType, subs: &Subs, buf: &mut String
} }
} }
FunctionOrTagUnion(tag_name, _, _) => { FunctionOrTagUnion(tag_name, _, ext_var) => {
let interns = &env.interns; let interns = &env.interns;
let home = env.home; let home = env.home;
@ -499,6 +499,17 @@ fn write_flat_type(env: &Env, flat_type: FlatType, subs: &Subs, buf: &mut String
buf.push_str(&tag_name.as_string(&interns, home)); buf.push_str(&tag_name.as_string(&interns, home));
buf.push_str(" ]"); buf.push_str(" ]");
let mut sorted_fields = vec![(tag_name, vec![])];
let ext_content = chase_ext_tag_union(subs, ext_var, &mut sorted_fields);
if let Err((_, content)) = ext_content {
// This is an open tag union, so print the variable
// right after the ']'
//
// e.g. the "*" at the end of `{ x: I64 }*`
// or the "r" at the end of `{ x: I64 }r`
write_content(env, content, subs, buf, parens)
}
} }
RecursiveTagUnion(rec_var, tags, ext_var) => { RecursiveTagUnion(rec_var, tags, ext_var) => {

View file

@ -237,6 +237,10 @@ fn unify_structure(
// unify the structure with this recursive tag union // unify the structure with this recursive tag union
unify_pool(subs, pool, ctx.first, *structure) unify_pool(subs, pool, ctx.first, *structure)
} }
FlatType::FunctionOrTagUnion(_, _, _) => {
// unify the structure with this unrecursive tag union
unify_pool(subs, pool, ctx.first, *structure)
}
_ => todo!("rec structure {:?}", &flat_type), _ => todo!("rec structure {:?}", &flat_type),
}, },
@ -979,10 +983,92 @@ fn unify_flat_type(
} }
} }
(TagUnion(tags, ext), Func(args, closure, ret)) if tags.len() == 1 => { (TagUnion(tags, ext), Func(args, closure, ret)) if tags.len() == 1 => {
unify_tag_union_and_func(tags, args, subs, pool, ctx, ext, ret, closure, true) // unify_tag_union_and_func(tags, args, subs, pool, ctx, ext, ret, closure, true)
panic!()
} }
(Func(args, closure, ret), TagUnion(tags, ext)) if tags.len() == 1 => { (Func(args, closure, ret), TagUnion(tags, ext)) if tags.len() == 1 => {
unify_tag_union_and_func(tags, args, subs, pool, ctx, ext, ret, closure, false) // unify_tag_union_and_func(tags, args, subs, pool, ctx, ext, ret, closure, false)
panic!()
}
(FunctionOrTagUnion(tag_name, _, ext), Func(args, closure, ret)) => {
unify_function_or_tag_union_and_func(
tag_name, args, subs, pool, ctx, ext, ret, closure, true,
)
}
(Func(args, closure, ret), FunctionOrTagUnion(tag_name, _, ext)) => {
unify_function_or_tag_union_and_func(
tag_name, args, subs, pool, ctx, ext, ret, closure, false,
)
}
(FunctionOrTagUnion(tag_name_1, _, ext_1), FunctionOrTagUnion(tag_name_2, _, ext_2)) => {
if tag_name_1 == tag_name_2 {
let problems = unify_pool(subs, pool, *ext_1, *ext_2);
if problems.is_empty() {
let desc = subs.get(ctx.second);
merge(subs, ctx, desc.content)
} else {
problems
}
} else {
let mut tags1 = MutMap::default();
tags1.insert(tag_name_1.clone(), vec![]);
let union1 = gather_tags(subs, tags1, *ext_1);
let mut tags2 = MutMap::default();
tags2.insert(tag_name_2.clone(), vec![]);
let union2 = gather_tags(subs, tags2, *ext_2);
unify_tag_union(subs, pool, ctx, union1, union2, (None, None))
}
}
(TagUnion(tags1, ext1), FunctionOrTagUnion(tag_name, _, ext2)) => {
let union1 = gather_tags(subs, tags1.clone(), *ext1);
let mut tags2 = MutMap::default();
tags2.insert(tag_name.clone(), vec![]);
let union2 = gather_tags(subs, tags2, *ext2);
unify_tag_union(subs, pool, ctx, union1, union2, (None, None))
}
(FunctionOrTagUnion(tag_name, _, ext1), TagUnion(tags2, ext2)) => {
let mut tags1 = MutMap::default();
tags1.insert(tag_name.clone(), vec![]);
let union1 = gather_tags(subs, tags1, *ext1);
let union2 = gather_tags(subs, tags2.clone(), *ext2);
unify_tag_union(subs, pool, ctx, union1, union2, (None, None))
}
(RecursiveTagUnion(recursion_var, tags1, ext1), FunctionOrTagUnion(tag_name, _, ext2)) => {
debug_assert!(is_recursion_var(subs, *recursion_var));
let union1 = gather_tags(subs, tags1.clone(), *ext1);
let mut tags2 = MutMap::default();
tags2.insert(tag_name.clone(), vec![]);
let union2 = gather_tags(subs, tags2.clone(), *ext2);
unify_tag_union(
subs,
pool,
ctx,
union1,
union2,
(Some(*recursion_var), None),
)
}
(FunctionOrTagUnion(tag_name, _, ext1), RecursiveTagUnion(recursion_var, tags2, ext2)) => {
debug_assert!(is_recursion_var(subs, *recursion_var));
let mut tags1 = MutMap::default();
tags1.insert(tag_name.clone(), vec![]);
let union1 = gather_tags(subs, tags1.clone(), *ext1);
let union2 = gather_tags(subs, tags2.clone(), *ext2);
unify_tag_union_not_recursive_recursive(subs, pool, ctx, union1, union2, *recursion_var)
} }
(other1, other2) => mismatch!( (other1, other2) => mismatch!(
"Trying to unify two flat types that are incompatible: {:?} ~ {:?}", "Trying to unify two flat types that are incompatible: {:?} ~ {:?}",
@ -1216,3 +1302,44 @@ fn unify_tag_union_and_func(
) )
} }
} }
#[allow(clippy::too_many_arguments)]
fn unify_function_or_tag_union_and_func(
tag_name: &TagName,
args: &[Variable],
subs: &mut Subs,
pool: &mut Pool,
ctx: &Context,
ext: &Variable,
ret: &Variable,
closure: &Variable,
left: bool,
) -> Outcome {
use FlatType::*;
let mut new_tags = MutMap::with_capacity_and_hasher(1, default_hasher());
new_tags.insert(tag_name.clone(), args.to_owned());
let content = Structure(TagUnion(new_tags, *ext));
let new_tag_union_var = fresh(subs, pool, ctx, content);
let problems = if left {
unify_pool(subs, pool, new_tag_union_var, *ret)
} else {
unify_pool(subs, pool, *ret, new_tag_union_var)
};
if problems.is_empty() {
let desc = if left {
subs.get(ctx.second)
} else {
subs.get(ctx.first)
};
subs.union(ctx.first, ctx.second, desc);
}
problems
}