diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index dc7a1ab487..074e2e2522 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -7,7 +7,7 @@ use crate::llvm::build_list::{ }; use crate::llvm::compare::{build_eq, build_neq}; use crate::llvm::convert::{ - basic_type_from_layout, collection, get_fn_type, get_ptr_type, ptr_int, + basic_type_from_layout, block_of_memory, collection, get_fn_type, get_ptr_type, ptr_int, }; use bumpalo::collections::Vec; use bumpalo::Bump; @@ -478,7 +478,6 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( debug_assert!(*union_size > 1); let ptr_size = env.ptr_bytes; - dbg!(&tag_layout); let mut filler = tag_layout.stack_size(ptr_size); let ctx = env.context; @@ -513,7 +512,6 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( ptr, ctx.i64_type().ptr_type(AddressSpace::Generic).into(), ); - dbg!(&ptr); field_vals.push(ptr); } else { field_vals.push(val); @@ -638,14 +636,15 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( .expect("desired field did not decode"); if let Some(Layout::RecursivePointer) = field_layouts.get(*index as usize) { + let struct_layout = Layout::Struct(field_layouts); + let desired_type = block_of_memory(env.context, &struct_layout, env.ptr_bytes); + // the value is a pointer to the actual value; load that value! + use inkwell::types::BasicType; let ptr = cast_basic_basic( builder, result, - struct_value - .get_type() - .ptr_type(AddressSpace::Generic) - .into(), + desired_type.ptr_type(AddressSpace::Generic).into(), ); builder.build_load(ptr.into_pointer_value(), "load_recursive_field") } else { @@ -819,7 +818,7 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>( Expr::AccessAtIndex { field_layouts, .. } => { let layout = Layout::Struct(field_layouts); - basic_type_from_layout(env.arena, context, &layout, env.ptr_bytes) + block_of_memory(env.context, &layout, env.ptr_bytes) } _ => unreachable!( "a recursive pointer can only be loaded from a recursive tag union" @@ -1115,7 +1114,9 @@ fn decrement_refcount_layout<'a, 'ctx, 'env>( } } } - RecursiveUnion(_) => todo!("TODO implement decrement layout of recursive tag union"), + RecursiveUnion(_) => { + println!("TODO implement decrement layout of recursive tag union"); + } RecursivePointer => todo!("TODO implement decrement layout of recursive tag union"), Union(tags) => { debug_assert!(!tags.is_empty()); diff --git a/compiler/gen/src/llvm/convert.rs b/compiler/gen/src/llvm/convert.rs index ccd68b3964..d415fb0d41 100644 --- a/compiler/gen/src/llvm/convert.rs +++ b/compiler/gen/src/llvm/convert.rs @@ -107,37 +107,7 @@ pub fn basic_type_from_layout<'ctx>( .struct_type(field_types.into_bump_slice(), false) .as_basic_type_enum() } - RecursiveUnion(_) | Union(_) => { - // TODO make this dynamic - let ptr_size = std::mem::size_of::(); - let union_size = layout.stack_size(ptr_size as u32); - - // The memory layout of Union is a bit tricky. - // We have tags with different memory layouts, that are part of the same type. - // For llvm, all tags must have the same memory layout. - // - // So, we convert all tags to a layout of bytes of some size. - // It turns out that encoding to i64 for as many elements as possible is - // a nice optimization, the remainder is encoded as bytes. - - let num_i64 = union_size / 8; - let num_i8 = union_size % 8; - - let i64_array_type = context.i64_type().array_type(num_i64).as_basic_type_enum(); - - if num_i8 == 0 { - // the object fits perfectly in some number of i64's - // (i.e. the size is a multiple of 8 bytes) - context.struct_type(&[i64_array_type], false).into() - } else { - // there are some trailing bytes at the end - let i8_array_type = context.i8_type().array_type(num_i8).as_basic_type_enum(); - - context - .struct_type(&[i64_array_type, i8_array_type], false) - .into() - } - } + RecursiveUnion(_) | Union(_) => block_of_memory(context, layout, ptr_bytes), RecursivePointer => { // TODO make this dynamic context @@ -165,6 +135,41 @@ pub fn basic_type_from_layout<'ctx>( } } +pub fn block_of_memory<'ctx>( + context: &'ctx Context, + layout: &Layout<'_>, + ptr_bytes: u32, +) -> BasicTypeEnum<'ctx> { + // TODO make this dynamic + let union_size = layout.stack_size(ptr_bytes as u32); + + // The memory layout of Union is a bit tricky. + // We have tags with different memory layouts, that are part of the same type. + // For llvm, all tags must have the same memory layout. + // + // So, we convert all tags to a layout of bytes of some size. + // It turns out that encoding to i64 for as many elements as possible is + // a nice optimization, the remainder is encoded as bytes. + + let num_i64 = union_size / 8; + let num_i8 = union_size % 8; + + let i64_array_type = context.i64_type().array_type(num_i64).as_basic_type_enum(); + + if num_i8 == 0 { + // the object fits perfectly in some number of i64's + // (i.e. the size is a multiple of 8 bytes) + context.struct_type(&[i64_array_type], false).into() + } else { + // there are some trailing bytes at the end + let i8_array_type = context.i8_type().array_type(num_i8).as_basic_type_enum(); + + context + .struct_type(&[i64_array_type, i8_array_type], false) + .into() + } +} + /// Two usize values. Could be a wrapper for a List or a Str. /// /// It would be nicer if we could store this as a tuple containing one usize diff --git a/compiler/gen/tests/gen_primitives.rs b/compiler/gen/tests/gen_primitives.rs index f301c9626b..a69aefe045 100644 --- a/compiler/gen/tests/gen_primitives.rs +++ b/compiler/gen/tests/gen_primitives.rs @@ -472,7 +472,7 @@ mod gen_primitives { } #[test] - fn peano() { + fn peano1() { assert_evals_to!( indoc!( r#" @@ -482,8 +482,8 @@ mod gen_primitives { three = S (S (S Z)) when three is + Z -> 2 S _ -> 1 - Z -> 0 "# ), 1, @@ -511,4 +511,110 @@ mod gen_primitives { i64 ); } + + #[test] + fn linked_list_len_0() { + assert_evals_to!( + indoc!( + r#" + LinkedList a : [ Nil, Cons a (LinkedList a) ] + + nil : LinkedList Int + nil = Nil + + length : LinkedList a -> Int + length = \list -> + when list is + Nil -> 0 + Cons _ rest -> 1 + length rest + + + length nil + "# + ), + 0, + i64 + ); + } + + #[test] + fn linked_list_len_3() { + assert_evals_to!( + indoc!( + r#" + LinkedList a : [ Nil, Cons a (LinkedList a) ] + + three : LinkedList Int + three = Cons 3 (Cons 2 (Cons 1 Nil)) + + length : LinkedList a -> Int + length = \list -> + when list is + Nil -> 0 + Cons _ rest -> 1 + length rest + + + length three + "# + ), + 3, + i64 + ); + } + + #[test] + fn linked_list_sum() { + assert_evals_to!( + indoc!( + r#" + LinkedList a : [ Nil, Cons a (LinkedList a) ] + + three : LinkedList Int + three = Cons 3 (Cons 2 (Cons 1 Nil)) + + sum : LinkedList a -> Int + sum = \list -> + when list is + Nil -> 0 + Cons x rest -> x + sum rest + + sum three + "# + ), + 3 + 2 + 1, + i64 + ); + } + + #[test] + #[ignore] + fn linked_list_map() { + // `f` is not actually a function, so the call to it fails currently + assert_evals_to!( + indoc!( + r#" + LinkedList a : [ Nil, Cons a (LinkedList a) ] + + three : LinkedList Int + three = Cons 3 (Cons 2 (Cons 1 Nil)) + + sum : LinkedList a -> Int + sum = \list -> + when list is + Nil -> 0 + Cons x rest -> x + sum rest + + map : (a -> b), LinkedList a -> LinkedList b + map = \f, list -> + when list is + Nil -> Nil + Cons x rest -> Cons (f x) (map f rest) + + sum (map (\_ -> 1) three) + "# + ), + 3, + i64 + ); + } } diff --git a/compiler/gen/tests/helpers/eval.rs b/compiler/gen/tests/helpers/eval.rs index ad8743c7a8..42834c10f9 100644 --- a/compiler/gen/tests/helpers/eval.rs +++ b/compiler/gen/tests/helpers/eval.rs @@ -184,7 +184,7 @@ pub fn helper_without_uniqueness<'a>( ); // Uncomment this to see the module's un-optimized LLVM instruction output: - env.module.print_to_stderr(); + // env.module.print_to_stderr(); if main_fn.verify(true) { function_pass.run_on(&main_fn); diff --git a/compiler/mono/src/layout.rs b/compiler/mono/src/layout.rs index 42a23d5e36..44f0027d36 100644 --- a/compiler/mono/src/layout.rs +++ b/compiler/mono/src/layout.rs @@ -116,7 +116,7 @@ impl<'a> Layout<'a> { Ok(Layout::RecursivePointer) } else { let content = env.subs.get_without_compacting(var).content; - println!("{:?} {:?}", var, &content); + // println!("{:?} {:?}", var, &content); Self::new_help(env, content) } } @@ -355,7 +355,7 @@ fn layout_from_flat_type<'a>( // Num.Num should only ever have 1 argument, e.g. Num.Num Int.Integer debug_assert_eq!(args.len(), 1); - let var = args.iter().next().unwrap(); + let var = args.get(0).unwrap(); let content = subs.get_without_compacting(*var).content; layout_from_num_content(content) @@ -483,7 +483,7 @@ fn layout_from_flat_type<'a>( let mut tag_layout = Vec::with_capacity_in(variables.len() + 1, arena); // store the discriminant - tag_layout.push(Layout::Builtin(Builtin::Int8)); + tag_layout.push(Layout::Builtin(Builtin::Int64)); for var in variables { // TODO does this cause problems with mutually recursive unions? @@ -584,6 +584,9 @@ pub fn union_sorted_tags<'a>(arena: &'a Bump, var: Variable, subs: &Subs) -> Uni fn get_recursion_var(subs: &Subs, var: Variable) -> Option { match subs.get_without_compacting(var).content { Content::Structure(FlatType::RecursiveTagUnion(rec_var, _, _)) => Some(rec_var), + Content::Structure(FlatType::Apply(Symbol::ATTR_ATTR, args)) => { + get_recursion_var(subs, args[1]) + } Content::Alias(_, _, actual) => get_recursion_var(subs, actual), _ => None, } diff --git a/compiler/mono/tests/test_mono.rs b/compiler/mono/tests/test_mono.rs index ff966d66e2..43b88cf220 100644 --- a/compiler/mono/tests/test_mono.rs +++ b/compiler/mono/tests/test_mono.rs @@ -1559,6 +1559,47 @@ mod test_mono { ) } + #[test] + fn peano1() { + compiles_to_ir( + indoc!( + r#" + Peano : [ S Peano, Z ] + + three : Peano + three = S (S (S Z)) + + when three is + Z -> 0 + S _ -> 1 + "# + ), + indoc!( + r#" + let Test.9 = 0i64; + let Test.11 = 0i64; + let Test.13 = 0i64; + let Test.15 = 1i64; + let Test.14 = Z Test.15; + let Test.12 = S Test.13 Test.14; + let Test.10 = S Test.11 Test.12; + let Test.1 = S Test.9 Test.10; + let Test.5 = true; + let Test.7 = Index 0 Test.1; + let Test.6 = 1i64; + let Test.8 = lowlevel Eq Test.6 Test.7; + let Test.4 = lowlevel And Test.8 Test.5; + if Test.4 then + let Test.2 = 0i64; + ret Test.2; + else + let Test.3 = 1i64; + ret Test.3; + "# + ), + ) + } + #[test] fn peano2() { compiles_to_ir( diff --git a/compiler/reporting/tests/test_reporting.rs b/compiler/reporting/tests/test_reporting.rs index 986494edc6..0e6390165f 100644 --- a/compiler/reporting/tests/test_reporting.rs +++ b/compiler/reporting/tests/test_reporting.rs @@ -3052,6 +3052,9 @@ mod test_reporting { #[test] fn two_different_cons() { + // TODO investigate what is happening here; + // while it makes some kind of sense to print the recursion var as infinite, + // it's not very helpful in practice. report_problem_as( indoc!( r#" @@ -3075,11 +3078,11 @@ mod test_reporting { This `Cons` global tag application has the type: - [ Cons {} [ Cons Str [ Cons {} a, Nil ] as a, Nil ], Nil ] + [ Cons {} [ Cons Str [ Cons {} ∞, Nil ] as ∞, Nil ], Nil ] But the type annotation on `x` says it should be: - [ Cons {} a, Nil ] as a + [ Cons {} ∞, Nil ] as ∞ "# ), ) diff --git a/compiler/unify/src/unify.rs b/compiler/unify/src/unify.rs index f917bdc998..f62deead83 100644 --- a/compiler/unify/src/unify.rs +++ b/compiler/unify/src/unify.rs @@ -565,43 +565,84 @@ fn unify_shared_tags( // and so on until the whole non-recursive tag union can be unified with it. let mut problems = Vec::new(); + let attr_wrapped = match (subs.get(expected).content, subs.get(actual).content) { + ( + Content::Structure(FlatType::Apply(Symbol::ATTR_ATTR, expected_args)), + Content::Structure(FlatType::Apply(Symbol::ATTR_ATTR, actual_args)), + ) => Some(( + expected_args[0], + expected_args[1], + actual_args[0], + actual_args[1], + )), + _ => None, + }; + if let Some(rvar) = recursion_var { - if expected == rvar { - problems.extend(unify_pool(subs, pool, actual, ctx.second)); - println!("A"); - } else if is_structure(actual, subs) { - // the recursion variable is hidden behind some structure (commonly an Attr - // with uniqueness inference). Thus we must expand the recursive tag union to - // unify if with the non-recursive one. Thus: + match attr_wrapped { + None => { + if expected == rvar { + if actual == rvar { + problems.extend(unify_pool(subs, pool, expected, actual)); + } else { + problems.extend(unify_pool(subs, pool, actual, ctx.second)); - // replace the rvar with ctx.second (the whole recursive tag union) in expected - subs.explicit_substitute(rvar, ctx.second, expected); + // this unification is required for layout generation, + // but causes worse error messages + problems.extend(unify_pool(subs, pool, expected, actual)); + } + } else if is_structure(actual, subs) { + // the recursion variable is hidden behind some structure (commonly an Attr + // with uniqueness inference). Thus we must expand the recursive tag union to + // unify if with the non-recursive one. Thus: - // but, by the `is_structure` condition above, only if we're unifying with a structure! - // when `actual` is just a flex/rigid variable, the substitution would expand a - // recursive tag union infinitely! + // replace the rvar with ctx.second (the whole recursive tag union) in expected + subs.explicit_substitute(rvar, ctx.second, expected); - problems.extend(unify_pool(subs, pool, actual, expected)); - println!("B"); - } else { - // unification with a non-structure is trivial - problems.extend(unify_pool(subs, pool, actual, expected)); - println!("C"); + // but, by the `is_structure` condition above, only if we're unifying with a structure! + // when `actual` is just a flex/rigid variable, the substitution would expand a + // recursive tag union infinitely! + + problems.extend(unify_pool(subs, pool, actual, expected)); + } else { + // unification with a non-structure is trivial + problems.extend(unify_pool(subs, pool, actual, expected)); + } + } + Some((_expected_uvar, inner_expected, _actual_uvar, inner_actual)) => { + if inner_expected == rvar { + if inner_actual == rvar { + problems.extend(unify_pool(subs, pool, actual, expected)); + } else { + problems.extend(unify_pool(subs, pool, inner_actual, ctx.second)); + problems.extend(unify_pool(subs, pool, expected, actual)); + } + } else if is_structure(inner_actual, subs) { + // the recursion variable is hidden behind some structure (commonly an Attr + // with uniqueness inference). Thus we must expand the recursive tag union to + // unify if with the non-recursive one. Thus: + + // replace the rvar with ctx.second (the whole recursive tag union) in expected + subs.explicit_substitute(rvar, ctx.second, inner_expected); + + // but, by the `is_structure` condition above, only if we're unifying with a structure! + // when `actual` is just a flex/rigid variable, the substitution would expand a + // recursive tag union infinitely! + + problems.extend(unify_pool(subs, pool, actual, expected)); + } else { + // unification with a non-structure is trivial + problems.extend(unify_pool(subs, pool, actual, expected)); + } + } } } else { // we always unify NonRecursive with Recursive, so this should never happen debug_assert_ne!(Some(actual), recursion_var); problems.extend(unify_pool(subs, pool, actual, expected)); - println!("D"); }; - // TODO this changes some error messages - // but is important for the inference of recursive types - if problems.is_empty() { - problems.extend(unify_pool(subs, pool, expected, actual)); - } - if problems.is_empty() { // debug_assert_eq!(subs.get_root_key(actual), subs.get_root_key(expected)); matching_vars.push(actual); @@ -695,8 +736,19 @@ fn unify_flat_type( unify_tag_union(subs, pool, ctx, union1, union2, (None, None)) } - (RecursiveTagUnion(_, _, _), TagUnion(_, _)) => { - unreachable!("unify of recursive with non-recursive tag union should not occur"); + (RecursiveTagUnion(recursion_var, tags1, ext1), TagUnion(tags2, ext2)) => { + // unreachable!("unify of recursive with non-recursive tag union should not occur"); + let union1 = gather_tags(subs, tags1.clone(), *ext1); + let union2 = gather_tags(subs, tags2.clone(), *ext2); + + unify_tag_union( + subs, + pool, + ctx, + union1, + union2, + (None, Some(*recursion_var)), + ) } (TagUnion(tags1, ext1), RecursiveTagUnion(recursion_var, tags2, ext2)) => {