recursive tag unions

This commit is contained in:
Folkert 2020-08-31 14:29:09 +02:00
parent ba186bfe09
commit f9cf4ea371
8 changed files with 286 additions and 75 deletions

View file

@ -7,7 +7,7 @@ use crate::llvm::build_list::{
}; };
use crate::llvm::compare::{build_eq, build_neq}; use crate::llvm::compare::{build_eq, build_neq};
use crate::llvm::convert::{ use crate::llvm::convert::{
basic_type_from_layout, collection, get_fn_type, get_ptr_type, ptr_int, basic_type_from_layout, block_of_memory, collection, get_fn_type, get_ptr_type, ptr_int,
}; };
use bumpalo::collections::Vec; use bumpalo::collections::Vec;
use bumpalo::Bump; use bumpalo::Bump;
@ -478,7 +478,6 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
debug_assert!(*union_size > 1); debug_assert!(*union_size > 1);
let ptr_size = env.ptr_bytes; let ptr_size = env.ptr_bytes;
dbg!(&tag_layout);
let mut filler = tag_layout.stack_size(ptr_size); let mut filler = tag_layout.stack_size(ptr_size);
let ctx = env.context; let ctx = env.context;
@ -513,7 +512,6 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
ptr, ptr,
ctx.i64_type().ptr_type(AddressSpace::Generic).into(), ctx.i64_type().ptr_type(AddressSpace::Generic).into(),
); );
dbg!(&ptr);
field_vals.push(ptr); field_vals.push(ptr);
} else { } else {
field_vals.push(val); field_vals.push(val);
@ -638,14 +636,15 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
.expect("desired field did not decode"); .expect("desired field did not decode");
if let Some(Layout::RecursivePointer) = field_layouts.get(*index as usize) { if let Some(Layout::RecursivePointer) = field_layouts.get(*index as usize) {
let struct_layout = Layout::Struct(field_layouts);
let desired_type = block_of_memory(env.context, &struct_layout, env.ptr_bytes);
// the value is a pointer to the actual value; load that value! // the value is a pointer to the actual value; load that value!
use inkwell::types::BasicType;
let ptr = cast_basic_basic( let ptr = cast_basic_basic(
builder, builder,
result, result,
struct_value desired_type.ptr_type(AddressSpace::Generic).into(),
.get_type()
.ptr_type(AddressSpace::Generic)
.into(),
); );
builder.build_load(ptr.into_pointer_value(), "load_recursive_field") builder.build_load(ptr.into_pointer_value(), "load_recursive_field")
} else { } else {
@ -819,7 +818,7 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>(
Expr::AccessAtIndex { field_layouts, .. } => { Expr::AccessAtIndex { field_layouts, .. } => {
let layout = Layout::Struct(field_layouts); let layout = Layout::Struct(field_layouts);
basic_type_from_layout(env.arena, context, &layout, env.ptr_bytes) block_of_memory(env.context, &layout, env.ptr_bytes)
} }
_ => unreachable!( _ => unreachable!(
"a recursive pointer can only be loaded from a recursive tag union" "a recursive pointer can only be loaded from a recursive tag union"
@ -1115,7 +1114,9 @@ fn decrement_refcount_layout<'a, 'ctx, 'env>(
} }
} }
} }
RecursiveUnion(_) => todo!("TODO implement decrement layout of recursive tag union"), RecursiveUnion(_) => {
println!("TODO implement decrement layout of recursive tag union");
}
RecursivePointer => todo!("TODO implement decrement layout of recursive tag union"), RecursivePointer => todo!("TODO implement decrement layout of recursive tag union"),
Union(tags) => { Union(tags) => {
debug_assert!(!tags.is_empty()); debug_assert!(!tags.is_empty());

View file

@ -107,37 +107,7 @@ pub fn basic_type_from_layout<'ctx>(
.struct_type(field_types.into_bump_slice(), false) .struct_type(field_types.into_bump_slice(), false)
.as_basic_type_enum() .as_basic_type_enum()
} }
RecursiveUnion(_) | Union(_) => { RecursiveUnion(_) | Union(_) => block_of_memory(context, layout, ptr_bytes),
// TODO make this dynamic
let ptr_size = std::mem::size_of::<i64>();
let union_size = layout.stack_size(ptr_size as u32);
// The memory layout of Union is a bit tricky.
// We have tags with different memory layouts, that are part of the same type.
// For llvm, all tags must have the same memory layout.
//
// So, we convert all tags to a layout of bytes of some size.
// It turns out that encoding to i64 for as many elements as possible is
// a nice optimization, the remainder is encoded as bytes.
let num_i64 = union_size / 8;
let num_i8 = union_size % 8;
let i64_array_type = context.i64_type().array_type(num_i64).as_basic_type_enum();
if num_i8 == 0 {
// the object fits perfectly in some number of i64's
// (i.e. the size is a multiple of 8 bytes)
context.struct_type(&[i64_array_type], false).into()
} else {
// there are some trailing bytes at the end
let i8_array_type = context.i8_type().array_type(num_i8).as_basic_type_enum();
context
.struct_type(&[i64_array_type, i8_array_type], false)
.into()
}
}
RecursivePointer => { RecursivePointer => {
// TODO make this dynamic // TODO make this dynamic
context context
@ -165,6 +135,41 @@ pub fn basic_type_from_layout<'ctx>(
} }
} }
pub fn block_of_memory<'ctx>(
context: &'ctx Context,
layout: &Layout<'_>,
ptr_bytes: u32,
) -> BasicTypeEnum<'ctx> {
// TODO make this dynamic
let union_size = layout.stack_size(ptr_bytes as u32);
// The memory layout of Union is a bit tricky.
// We have tags with different memory layouts, that are part of the same type.
// For llvm, all tags must have the same memory layout.
//
// So, we convert all tags to a layout of bytes of some size.
// It turns out that encoding to i64 for as many elements as possible is
// a nice optimization, the remainder is encoded as bytes.
let num_i64 = union_size / 8;
let num_i8 = union_size % 8;
let i64_array_type = context.i64_type().array_type(num_i64).as_basic_type_enum();
if num_i8 == 0 {
// the object fits perfectly in some number of i64's
// (i.e. the size is a multiple of 8 bytes)
context.struct_type(&[i64_array_type], false).into()
} else {
// there are some trailing bytes at the end
let i8_array_type = context.i8_type().array_type(num_i8).as_basic_type_enum();
context
.struct_type(&[i64_array_type, i8_array_type], false)
.into()
}
}
/// Two usize values. Could be a wrapper for a List or a Str. /// Two usize values. Could be a wrapper for a List or a Str.
/// ///
/// It would be nicer if we could store this as a tuple containing one usize /// It would be nicer if we could store this as a tuple containing one usize

View file

@ -472,7 +472,7 @@ mod gen_primitives {
} }
#[test] #[test]
fn peano() { fn peano1() {
assert_evals_to!( assert_evals_to!(
indoc!( indoc!(
r#" r#"
@ -482,8 +482,8 @@ mod gen_primitives {
three = S (S (S Z)) three = S (S (S Z))
when three is when three is
Z -> 2
S _ -> 1 S _ -> 1
Z -> 0
"# "#
), ),
1, 1,
@ -511,4 +511,110 @@ mod gen_primitives {
i64 i64
); );
} }
#[test]
fn linked_list_len_0() {
assert_evals_to!(
indoc!(
r#"
LinkedList a : [ Nil, Cons a (LinkedList a) ]
nil : LinkedList Int
nil = Nil
length : LinkedList a -> Int
length = \list ->
when list is
Nil -> 0
Cons _ rest -> 1 + length rest
length nil
"#
),
0,
i64
);
}
#[test]
fn linked_list_len_3() {
assert_evals_to!(
indoc!(
r#"
LinkedList a : [ Nil, Cons a (LinkedList a) ]
three : LinkedList Int
three = Cons 3 (Cons 2 (Cons 1 Nil))
length : LinkedList a -> Int
length = \list ->
when list is
Nil -> 0
Cons _ rest -> 1 + length rest
length three
"#
),
3,
i64
);
}
#[test]
fn linked_list_sum() {
assert_evals_to!(
indoc!(
r#"
LinkedList a : [ Nil, Cons a (LinkedList a) ]
three : LinkedList Int
three = Cons 3 (Cons 2 (Cons 1 Nil))
sum : LinkedList a -> Int
sum = \list ->
when list is
Nil -> 0
Cons x rest -> x + sum rest
sum three
"#
),
3 + 2 + 1,
i64
);
}
#[test]
#[ignore]
fn linked_list_map() {
// `f` is not actually a function, so the call to it fails currently
assert_evals_to!(
indoc!(
r#"
LinkedList a : [ Nil, Cons a (LinkedList a) ]
three : LinkedList Int
three = Cons 3 (Cons 2 (Cons 1 Nil))
sum : LinkedList a -> Int
sum = \list ->
when list is
Nil -> 0
Cons x rest -> x + sum rest
map : (a -> b), LinkedList a -> LinkedList b
map = \f, list ->
when list is
Nil -> Nil
Cons x rest -> Cons (f x) (map f rest)
sum (map (\_ -> 1) three)
"#
),
3,
i64
);
}
} }

View file

@ -184,7 +184,7 @@ pub fn helper_without_uniqueness<'a>(
); );
// Uncomment this to see the module's un-optimized LLVM instruction output: // Uncomment this to see the module's un-optimized LLVM instruction output:
env.module.print_to_stderr(); // env.module.print_to_stderr();
if main_fn.verify(true) { if main_fn.verify(true) {
function_pass.run_on(&main_fn); function_pass.run_on(&main_fn);

View file

@ -116,7 +116,7 @@ impl<'a> Layout<'a> {
Ok(Layout::RecursivePointer) Ok(Layout::RecursivePointer)
} else { } else {
let content = env.subs.get_without_compacting(var).content; let content = env.subs.get_without_compacting(var).content;
println!("{:?} {:?}", var, &content); // println!("{:?} {:?}", var, &content);
Self::new_help(env, content) Self::new_help(env, content)
} }
} }
@ -355,7 +355,7 @@ fn layout_from_flat_type<'a>(
// Num.Num should only ever have 1 argument, e.g. Num.Num Int.Integer // Num.Num should only ever have 1 argument, e.g. Num.Num Int.Integer
debug_assert_eq!(args.len(), 1); debug_assert_eq!(args.len(), 1);
let var = args.iter().next().unwrap(); let var = args.get(0).unwrap();
let content = subs.get_without_compacting(*var).content; let content = subs.get_without_compacting(*var).content;
layout_from_num_content(content) layout_from_num_content(content)
@ -483,7 +483,7 @@ fn layout_from_flat_type<'a>(
let mut tag_layout = Vec::with_capacity_in(variables.len() + 1, arena); let mut tag_layout = Vec::with_capacity_in(variables.len() + 1, arena);
// store the discriminant // store the discriminant
tag_layout.push(Layout::Builtin(Builtin::Int8)); tag_layout.push(Layout::Builtin(Builtin::Int64));
for var in variables { for var in variables {
// TODO does this cause problems with mutually recursive unions? // TODO does this cause problems with mutually recursive unions?
@ -584,6 +584,9 @@ pub fn union_sorted_tags<'a>(arena: &'a Bump, var: Variable, subs: &Subs) -> Uni
fn get_recursion_var(subs: &Subs, var: Variable) -> Option<Variable> { fn get_recursion_var(subs: &Subs, var: Variable) -> Option<Variable> {
match subs.get_without_compacting(var).content { match subs.get_without_compacting(var).content {
Content::Structure(FlatType::RecursiveTagUnion(rec_var, _, _)) => Some(rec_var), Content::Structure(FlatType::RecursiveTagUnion(rec_var, _, _)) => Some(rec_var),
Content::Structure(FlatType::Apply(Symbol::ATTR_ATTR, args)) => {
get_recursion_var(subs, args[1])
}
Content::Alias(_, _, actual) => get_recursion_var(subs, actual), Content::Alias(_, _, actual) => get_recursion_var(subs, actual),
_ => None, _ => None,
} }

View file

@ -1559,6 +1559,47 @@ mod test_mono {
) )
} }
#[test]
fn peano1() {
compiles_to_ir(
indoc!(
r#"
Peano : [ S Peano, Z ]
three : Peano
three = S (S (S Z))
when three is
Z -> 0
S _ -> 1
"#
),
indoc!(
r#"
let Test.9 = 0i64;
let Test.11 = 0i64;
let Test.13 = 0i64;
let Test.15 = 1i64;
let Test.14 = Z Test.15;
let Test.12 = S Test.13 Test.14;
let Test.10 = S Test.11 Test.12;
let Test.1 = S Test.9 Test.10;
let Test.5 = true;
let Test.7 = Index 0 Test.1;
let Test.6 = 1i64;
let Test.8 = lowlevel Eq Test.6 Test.7;
let Test.4 = lowlevel And Test.8 Test.5;
if Test.4 then
let Test.2 = 0i64;
ret Test.2;
else
let Test.3 = 1i64;
ret Test.3;
"#
),
)
}
#[test] #[test]
fn peano2() { fn peano2() {
compiles_to_ir( compiles_to_ir(

View file

@ -3052,6 +3052,9 @@ mod test_reporting {
#[test] #[test]
fn two_different_cons() { fn two_different_cons() {
// TODO investigate what is happening here;
// while it makes some kind of sense to print the recursion var as infinite,
// it's not very helpful in practice.
report_problem_as( report_problem_as(
indoc!( indoc!(
r#" r#"
@ -3075,11 +3078,11 @@ mod test_reporting {
This `Cons` global tag application has the type: This `Cons` global tag application has the type:
[ Cons {} [ Cons Str [ Cons {} a, Nil ] as a, Nil ], Nil ] [ Cons {} [ Cons Str [ Cons {} , Nil ] as , Nil ], Nil ]
But the type annotation on `x` says it should be: But the type annotation on `x` says it should be:
[ Cons {} a, Nil ] as a [ Cons {} , Nil ] as
"# "#
), ),
) )

View file

@ -565,43 +565,84 @@ fn unify_shared_tags(
// and so on until the whole non-recursive tag union can be unified with it. // and so on until the whole non-recursive tag union can be unified with it.
let mut problems = Vec::new(); let mut problems = Vec::new();
let attr_wrapped = match (subs.get(expected).content, subs.get(actual).content) {
(
Content::Structure(FlatType::Apply(Symbol::ATTR_ATTR, expected_args)),
Content::Structure(FlatType::Apply(Symbol::ATTR_ATTR, actual_args)),
) => Some((
expected_args[0],
expected_args[1],
actual_args[0],
actual_args[1],
)),
_ => None,
};
if let Some(rvar) = recursion_var { if let Some(rvar) = recursion_var {
if expected == rvar { match attr_wrapped {
problems.extend(unify_pool(subs, pool, actual, ctx.second)); None => {
println!("A"); if expected == rvar {
} else if is_structure(actual, subs) { if actual == rvar {
// the recursion variable is hidden behind some structure (commonly an Attr problems.extend(unify_pool(subs, pool, expected, actual));
// with uniqueness inference). Thus we must expand the recursive tag union to } else {
// unify if with the non-recursive one. Thus: problems.extend(unify_pool(subs, pool, actual, ctx.second));
// replace the rvar with ctx.second (the whole recursive tag union) in expected // this unification is required for layout generation,
subs.explicit_substitute(rvar, ctx.second, expected); // but causes worse error messages
problems.extend(unify_pool(subs, pool, expected, actual));
}
} else if is_structure(actual, subs) {
// the recursion variable is hidden behind some structure (commonly an Attr
// with uniqueness inference). Thus we must expand the recursive tag union to
// unify if with the non-recursive one. Thus:
// but, by the `is_structure` condition above, only if we're unifying with a structure! // replace the rvar with ctx.second (the whole recursive tag union) in expected
// when `actual` is just a flex/rigid variable, the substitution would expand a subs.explicit_substitute(rvar, ctx.second, expected);
// recursive tag union infinitely!
problems.extend(unify_pool(subs, pool, actual, expected)); // but, by the `is_structure` condition above, only if we're unifying with a structure!
println!("B"); // when `actual` is just a flex/rigid variable, the substitution would expand a
} else { // recursive tag union infinitely!
// unification with a non-structure is trivial
problems.extend(unify_pool(subs, pool, actual, expected)); problems.extend(unify_pool(subs, pool, actual, expected));
println!("C"); } else {
// unification with a non-structure is trivial
problems.extend(unify_pool(subs, pool, actual, expected));
}
}
Some((_expected_uvar, inner_expected, _actual_uvar, inner_actual)) => {
if inner_expected == rvar {
if inner_actual == rvar {
problems.extend(unify_pool(subs, pool, actual, expected));
} else {
problems.extend(unify_pool(subs, pool, inner_actual, ctx.second));
problems.extend(unify_pool(subs, pool, expected, actual));
}
} else if is_structure(inner_actual, subs) {
// the recursion variable is hidden behind some structure (commonly an Attr
// with uniqueness inference). Thus we must expand the recursive tag union to
// unify if with the non-recursive one. Thus:
// replace the rvar with ctx.second (the whole recursive tag union) in expected
subs.explicit_substitute(rvar, ctx.second, inner_expected);
// but, by the `is_structure` condition above, only if we're unifying with a structure!
// when `actual` is just a flex/rigid variable, the substitution would expand a
// recursive tag union infinitely!
problems.extend(unify_pool(subs, pool, actual, expected));
} else {
// unification with a non-structure is trivial
problems.extend(unify_pool(subs, pool, actual, expected));
}
}
} }
} else { } else {
// we always unify NonRecursive with Recursive, so this should never happen // we always unify NonRecursive with Recursive, so this should never happen
debug_assert_ne!(Some(actual), recursion_var); debug_assert_ne!(Some(actual), recursion_var);
problems.extend(unify_pool(subs, pool, actual, expected)); problems.extend(unify_pool(subs, pool, actual, expected));
println!("D");
}; };
// TODO this changes some error messages
// but is important for the inference of recursive types
if problems.is_empty() {
problems.extend(unify_pool(subs, pool, expected, actual));
}
if problems.is_empty() { if problems.is_empty() {
// debug_assert_eq!(subs.get_root_key(actual), subs.get_root_key(expected)); // debug_assert_eq!(subs.get_root_key(actual), subs.get_root_key(expected));
matching_vars.push(actual); matching_vars.push(actual);
@ -695,8 +736,19 @@ fn unify_flat_type(
unify_tag_union(subs, pool, ctx, union1, union2, (None, None)) unify_tag_union(subs, pool, ctx, union1, union2, (None, None))
} }
(RecursiveTagUnion(_, _, _), TagUnion(_, _)) => { (RecursiveTagUnion(recursion_var, tags1, ext1), TagUnion(tags2, ext2)) => {
unreachable!("unify of recursive with non-recursive tag union should not occur"); // unreachable!("unify of recursive with non-recursive tag union should not occur");
let union1 = gather_tags(subs, tags1.clone(), *ext1);
let union2 = gather_tags(subs, tags2.clone(), *ext2);
unify_tag_union(
subs,
pool,
ctx,
union1,
union2,
(None, Some(*recursion_var)),
)
} }
(TagUnion(tags1, ext1), RecursiveTagUnion(recursion_var, tags2, ext2)) => { (TagUnion(tags1, ext1), RecursiveTagUnion(recursion_var, tags2, ext2)) => {