recursive tag unions

This commit is contained in:
Folkert 2020-08-31 14:29:09 +02:00
parent ba186bfe09
commit f9cf4ea371
8 changed files with 286 additions and 75 deletions

View file

@ -7,7 +7,7 @@ use crate::llvm::build_list::{
};
use crate::llvm::compare::{build_eq, build_neq};
use crate::llvm::convert::{
basic_type_from_layout, collection, get_fn_type, get_ptr_type, ptr_int,
basic_type_from_layout, block_of_memory, collection, get_fn_type, get_ptr_type, ptr_int,
};
use bumpalo::collections::Vec;
use bumpalo::Bump;
@ -478,7 +478,6 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
debug_assert!(*union_size > 1);
let ptr_size = env.ptr_bytes;
dbg!(&tag_layout);
let mut filler = tag_layout.stack_size(ptr_size);
let ctx = env.context;
@ -513,7 +512,6 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
ptr,
ctx.i64_type().ptr_type(AddressSpace::Generic).into(),
);
dbg!(&ptr);
field_vals.push(ptr);
} else {
field_vals.push(val);
@ -638,14 +636,15 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
.expect("desired field did not decode");
if let Some(Layout::RecursivePointer) = field_layouts.get(*index as usize) {
let struct_layout = Layout::Struct(field_layouts);
let desired_type = block_of_memory(env.context, &struct_layout, env.ptr_bytes);
// the value is a pointer to the actual value; load that value!
use inkwell::types::BasicType;
let ptr = cast_basic_basic(
builder,
result,
struct_value
.get_type()
.ptr_type(AddressSpace::Generic)
.into(),
desired_type.ptr_type(AddressSpace::Generic).into(),
);
builder.build_load(ptr.into_pointer_value(), "load_recursive_field")
} else {
@ -819,7 +818,7 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>(
Expr::AccessAtIndex { field_layouts, .. } => {
let layout = Layout::Struct(field_layouts);
basic_type_from_layout(env.arena, context, &layout, env.ptr_bytes)
block_of_memory(env.context, &layout, env.ptr_bytes)
}
_ => unreachable!(
"a recursive pointer can only be loaded from a recursive tag union"
@ -1115,7 +1114,9 @@ fn decrement_refcount_layout<'a, 'ctx, 'env>(
}
}
}
RecursiveUnion(_) => todo!("TODO implement decrement layout of recursive tag union"),
RecursiveUnion(_) => {
println!("TODO implement decrement layout of recursive tag union");
}
RecursivePointer => todo!("TODO implement decrement layout of recursive tag union"),
Union(tags) => {
debug_assert!(!tags.is_empty());

View file

@ -107,10 +107,41 @@ pub fn basic_type_from_layout<'ctx>(
.struct_type(field_types.into_bump_slice(), false)
.as_basic_type_enum()
}
RecursiveUnion(_) | Union(_) => {
RecursiveUnion(_) | Union(_) => block_of_memory(context, layout, ptr_bytes),
RecursivePointer => {
// TODO make this dynamic
let ptr_size = std::mem::size_of::<i64>();
let union_size = layout.stack_size(ptr_size as u32);
context
.i64_type()
.ptr_type(AddressSpace::Generic)
.as_basic_type_enum()
}
Builtin(builtin) => match builtin {
Int128 => context.i128_type().as_basic_type_enum(),
Int64 => context.i64_type().as_basic_type_enum(),
Int32 => context.i32_type().as_basic_type_enum(),
Int16 => context.i16_type().as_basic_type_enum(),
Int8 => context.i8_type().as_basic_type_enum(),
Int1 => context.bool_type().as_basic_type_enum(),
Float128 => context.f128_type().as_basic_type_enum(),
Float64 => context.f64_type().as_basic_type_enum(),
Float32 => context.f32_type().as_basic_type_enum(),
Float16 => context.f16_type().as_basic_type_enum(),
Map(_, _) | EmptyMap => panic!("TODO layout_to_basic_type for Builtin::Map"),
Set(_) | EmptySet => panic!("TODO layout_to_basic_type for Builtin::Set"),
List(_, _) | Str | EmptyStr => collection(context, ptr_bytes).into(),
EmptyList => BasicTypeEnum::StructType(collection(context, ptr_bytes)),
},
}
}
pub fn block_of_memory<'ctx>(
context: &'ctx Context,
layout: &Layout<'_>,
ptr_bytes: u32,
) -> BasicTypeEnum<'ctx> {
// TODO make this dynamic
let union_size = layout.stack_size(ptr_bytes as u32);
// The memory layout of Union is a bit tricky.
// We have tags with different memory layouts, that are part of the same type.
@ -138,32 +169,6 @@ pub fn basic_type_from_layout<'ctx>(
.into()
}
}
RecursivePointer => {
// TODO make this dynamic
context
.i64_type()
.ptr_type(AddressSpace::Generic)
.as_basic_type_enum()
}
Builtin(builtin) => match builtin {
Int128 => context.i128_type().as_basic_type_enum(),
Int64 => context.i64_type().as_basic_type_enum(),
Int32 => context.i32_type().as_basic_type_enum(),
Int16 => context.i16_type().as_basic_type_enum(),
Int8 => context.i8_type().as_basic_type_enum(),
Int1 => context.bool_type().as_basic_type_enum(),
Float128 => context.f128_type().as_basic_type_enum(),
Float64 => context.f64_type().as_basic_type_enum(),
Float32 => context.f32_type().as_basic_type_enum(),
Float16 => context.f16_type().as_basic_type_enum(),
Map(_, _) | EmptyMap => panic!("TODO layout_to_basic_type for Builtin::Map"),
Set(_) | EmptySet => panic!("TODO layout_to_basic_type for Builtin::Set"),
List(_, _) | Str | EmptyStr => collection(context, ptr_bytes).into(),
EmptyList => BasicTypeEnum::StructType(collection(context, ptr_bytes)),
},
}
}
/// Two usize values. Could be a wrapper for a List or a Str.
///

View file

@ -472,7 +472,7 @@ mod gen_primitives {
}
#[test]
fn peano() {
fn peano1() {
assert_evals_to!(
indoc!(
r#"
@ -482,8 +482,8 @@ mod gen_primitives {
three = S (S (S Z))
when three is
Z -> 2
S _ -> 1
Z -> 0
"#
),
1,
@ -511,4 +511,110 @@ mod gen_primitives {
i64
);
}
// Codegen test for a recursive tag union: length of an empty linked list.
//
// The embedded Roc program declares the recursive alias `LinkedList a`
// (`Nil` or `Cons a (LinkedList a)`), binds `nil` to `Nil`, and defines a
// recursive `length` that returns 0 for `Nil` and `1 + length rest` for `Cons`.
// `assert_evals_to!` receives the program text, the expected value, and the
// Rust type the result is read as.
#[test]
fn linked_list_len_0() {
assert_evals_to!(
indoc!(
r#"
LinkedList a : [ Nil, Cons a (LinkedList a) ]
nil : LinkedList Int
nil = Nil
length : LinkedList a -> Int
length = \list ->
when list is
Nil -> 0
Cons _ rest -> 1 + length rest
length nil
"#
),
// expected: `length` takes the `Nil` branch immediately
0,
i64
);
}
// Codegen test for a recursive tag union: length of a three-element list.
//
// Same `LinkedList a` alias and recursive `length` function as the empty-list
// case, but applied to `Cons 3 (Cons 2 (Cons 1 Nil))`, so the `Cons` branch
// (and therefore the recursive call on `rest`) is exercised three times
// before `Nil` is reached.
#[test]
fn linked_list_len_3() {
assert_evals_to!(
indoc!(
r#"
LinkedList a : [ Nil, Cons a (LinkedList a) ]
three : LinkedList Int
three = Cons 3 (Cons 2 (Cons 1 Nil))
length : LinkedList a -> Int
length = \list ->
when list is
Nil -> 0
Cons _ rest -> 1 + length rest
length three
"#
),
// expected: one per `Cons` cell
3,
i64
);
}
// Codegen test for a recursive tag union: sum of the elements of a list.
//
// Unlike the `length` tests, the `Cons` branch here binds and uses the
// element (`x`) as well as the tail, so both fields of the tag payload are
// read from the recursive structure.
//
// NOTE(review): the annotation `sum : LinkedList a -> Int` is more general
// than the body, which adds `x` to an Int — confirm whether
// `LinkedList Int -> Int` was intended.
#[test]
fn linked_list_sum() {
assert_evals_to!(
indoc!(
r#"
LinkedList a : [ Nil, Cons a (LinkedList a) ]
three : LinkedList Int
three = Cons 3 (Cons 2 (Cons 1 Nil))
sum : LinkedList a -> Int
sum = \list ->
when list is
Nil -> 0
Cons x rest -> x + sum rest
sum three
"#
),
// expected: written as the sum of the three stored elements
3 + 2 + 1,
i64
);
}
// Codegen test for mapping a function over a recursive linked list.
//
// The embedded Roc program defines `sum` and a recursive `map` over
// `LinkedList`, maps `\_ -> 1` over a three-element list, and sums the result,
// so the expected value is 3 (one per element).
//
// Marked `#[ignore]`: see the comment inside the function for the reason the
// test does not pass yet.
#[test]
#[ignore]
fn linked_list_map() {
// `f` is not actually a function, so the call to it fails currently
assert_evals_to!(
indoc!(
r#"
LinkedList a : [ Nil, Cons a (LinkedList a) ]
three : LinkedList Int
three = Cons 3 (Cons 2 (Cons 1 Nil))
sum : LinkedList a -> Int
sum = \list ->
when list is
Nil -> 0
Cons x rest -> x + sum rest
map : (a -> b), LinkedList a -> LinkedList b
map = \f, list ->
when list is
Nil -> Nil
Cons x rest -> Cons (f x) (map f rest)
sum (map (\_ -> 1) three)
"#
),
// expected: every element is replaced by 1, and there are three elements
3,
i64
);
}
}

View file

@ -184,7 +184,7 @@ pub fn helper_without_uniqueness<'a>(
);
// Uncomment this to see the module's un-optimized LLVM instruction output:
env.module.print_to_stderr();
// env.module.print_to_stderr();
if main_fn.verify(true) {
function_pass.run_on(&main_fn);

View file

@ -116,7 +116,7 @@ impl<'a> Layout<'a> {
Ok(Layout::RecursivePointer)
} else {
let content = env.subs.get_without_compacting(var).content;
println!("{:?} {:?}", var, &content);
// println!("{:?} {:?}", var, &content);
Self::new_help(env, content)
}
}
@ -355,7 +355,7 @@ fn layout_from_flat_type<'a>(
// Num.Num should only ever have 1 argument, e.g. Num.Num Int.Integer
debug_assert_eq!(args.len(), 1);
let var = args.iter().next().unwrap();
let var = args.get(0).unwrap();
let content = subs.get_without_compacting(*var).content;
layout_from_num_content(content)
@ -483,7 +483,7 @@ fn layout_from_flat_type<'a>(
let mut tag_layout = Vec::with_capacity_in(variables.len() + 1, arena);
// store the discriminant
tag_layout.push(Layout::Builtin(Builtin::Int8));
tag_layout.push(Layout::Builtin(Builtin::Int64));
for var in variables {
// TODO does this cause problems with mutually recursive unions?
@ -584,6 +584,9 @@ pub fn union_sorted_tags<'a>(arena: &'a Bump, var: Variable, subs: &Subs) -> Uni
fn get_recursion_var(subs: &Subs, var: Variable) -> Option<Variable> {
match subs.get_without_compacting(var).content {
Content::Structure(FlatType::RecursiveTagUnion(rec_var, _, _)) => Some(rec_var),
Content::Structure(FlatType::Apply(Symbol::ATTR_ATTR, args)) => {
get_recursion_var(subs, args[1])
}
Content::Alias(_, _, actual) => get_recursion_var(subs, actual),
_ => None,
}

View file

@ -1559,6 +1559,47 @@ mod test_mono {
)
}
// Mono IR test: a `when` over a value of the recursive tag union `Peano`.
//
// `compiles_to_ir` is given the Roc source (first string) and the IR text it
// is expected to produce (second string).
//
// NOTE(review): in the expected IR every `S`/`Z` application carries a leading
// i64 literal (0 for `S`, 1 for `Z`), and the branch condition compares
// `Index 0` of the scrutinee against 1 — presumably that first field is the
// tag id; confirm against the IR printer.
#[test]
fn peano1() {
compiles_to_ir(
indoc!(
r#"
Peano : [ S Peano, Z ]
three : Peano
three = S (S (S Z))
when three is
Z -> 0
S _ -> 1
"#
),
// expected IR: builds `S (S (S Z))` innermost-first, then branches on field 0
indoc!(
r#"
let Test.9 = 0i64;
let Test.11 = 0i64;
let Test.13 = 0i64;
let Test.15 = 1i64;
let Test.14 = Z Test.15;
let Test.12 = S Test.13 Test.14;
let Test.10 = S Test.11 Test.12;
let Test.1 = S Test.9 Test.10;
let Test.5 = true;
let Test.7 = Index 0 Test.1;
let Test.6 = 1i64;
let Test.8 = lowlevel Eq Test.6 Test.7;
let Test.4 = lowlevel And Test.8 Test.5;
if Test.4 then
let Test.2 = 0i64;
ret Test.2;
else
let Test.3 = 1i64;
ret Test.3;
"#
),
)
}
#[test]
fn peano2() {
compiles_to_ir(

View file

@ -3052,6 +3052,9 @@ mod test_reporting {
#[test]
fn two_different_cons() {
// TODO investigate what is happening here;
// while it makes some kind of sense to print the recursion var as infinite,
// it's not very helpful in practice.
report_problem_as(
indoc!(
r#"
@ -3075,11 +3078,11 @@ mod test_reporting {
This `Cons` global tag application has the type:
[ Cons {} [ Cons Str [ Cons {} a, Nil ] as a, Nil ], Nil ]
[ Cons {} [ Cons Str [ Cons {} , Nil ] as , Nil ], Nil ]
But the type annotation on `x` says it should be:
[ Cons {} a, Nil ] as a
[ Cons {} , Nil ] as
"#
),
)

View file

@ -565,10 +565,32 @@ fn unify_shared_tags(
// and so on until the whole non-recursive tag union can be unified with it.
let mut problems = Vec::new();
let attr_wrapped = match (subs.get(expected).content, subs.get(actual).content) {
(
Content::Structure(FlatType::Apply(Symbol::ATTR_ATTR, expected_args)),
Content::Structure(FlatType::Apply(Symbol::ATTR_ATTR, actual_args)),
) => Some((
expected_args[0],
expected_args[1],
actual_args[0],
actual_args[1],
)),
_ => None,
};
if let Some(rvar) = recursion_var {
match attr_wrapped {
None => {
if expected == rvar {
if actual == rvar {
problems.extend(unify_pool(subs, pool, expected, actual));
} else {
problems.extend(unify_pool(subs, pool, actual, ctx.second));
println!("A");
// this unification is required for layout generation,
// but causes worse error messages
problems.extend(unify_pool(subs, pool, expected, actual));
}
} else if is_structure(actual, subs) {
// the recursion variable is hidden behind some structure (commonly an Attr
// with uniqueness inference). Thus we must expand the recursive tag union to
@ -582,26 +604,45 @@ fn unify_shared_tags(
// recursive tag union infinitely!
problems.extend(unify_pool(subs, pool, actual, expected));
println!("B");
} else {
// unification with a non-structure is trivial
problems.extend(unify_pool(subs, pool, actual, expected));
println!("C");
}
}
Some((_expected_uvar, inner_expected, _actual_uvar, inner_actual)) => {
if inner_expected == rvar {
if inner_actual == rvar {
problems.extend(unify_pool(subs, pool, actual, expected));
} else {
problems.extend(unify_pool(subs, pool, inner_actual, ctx.second));
problems.extend(unify_pool(subs, pool, expected, actual));
}
} else if is_structure(inner_actual, subs) {
// the recursion variable is hidden behind some structure (commonly an Attr
// with uniqueness inference). Thus we must expand the recursive tag union to
// unify it with the non-recursive one. Thus:
// replace the rvar with ctx.second (the whole recursive tag union) in expected
subs.explicit_substitute(rvar, ctx.second, inner_expected);
// but, by the `is_structure` condition above, only if we're unifying with a structure!
// when `actual` is just a flex/rigid variable, the substitution would expand a
// recursive tag union infinitely!
problems.extend(unify_pool(subs, pool, actual, expected));
} else {
// unification with a non-structure is trivial
problems.extend(unify_pool(subs, pool, actual, expected));
}
}
}
} else {
// we always unify NonRecursive with Recursive, so this should never happen
debug_assert_ne!(Some(actual), recursion_var);
problems.extend(unify_pool(subs, pool, actual, expected));
println!("D");
};
// TODO this changes some error messages
// but is important for the inference of recursive types
if problems.is_empty() {
problems.extend(unify_pool(subs, pool, expected, actual));
}
if problems.is_empty() {
// debug_assert_eq!(subs.get_root_key(actual), subs.get_root_key(expected));
matching_vars.push(actual);
@ -695,8 +736,19 @@ fn unify_flat_type(
unify_tag_union(subs, pool, ctx, union1, union2, (None, None))
}
(RecursiveTagUnion(_, _, _), TagUnion(_, _)) => {
unreachable!("unify of recursive with non-recursive tag union should not occur");
(RecursiveTagUnion(recursion_var, tags1, ext1), TagUnion(tags2, ext2)) => {
// unreachable!("unify of recursive with non-recursive tag union should not occur");
let union1 = gather_tags(subs, tags1.clone(), *ext1);
let union2 = gather_tags(subs, tags2.clone(), *ext2);
unify_tag_union(
subs,
pool,
ctx,
union1,
union2,
(None, Some(*recursion_var)),
)
}
(TagUnion(tags1, ext1), RecursiveTagUnion(recursion_var, tags2, ext2)) => {