From 154b5cc29f7f2da24e5f7d51dda37f272e04d0b0 Mon Sep 17 00:00:00 2001 From: Folkert Date: Sat, 14 Nov 2020 02:49:28 +0100 Subject: [PATCH] get RBTree.balance to compile --- compiler/gen/src/llvm/build.rs | 36 ++++++------ compiler/gen/src/llvm/refcounting.rs | 88 ++++++++++------------------ compiler/gen/tests/gen_primitives.rs | 61 ++++++++++++++++--- compiler/mono/src/ir.rs | 1 + compiler/mono/src/layout.rs | 1 + 5 files changed, 104 insertions(+), 83 deletions(-) diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index 209dc6e02b..c644a75ca5 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -752,31 +752,33 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( let mut field_types = Vec::with_capacity_in(num_fields, env.arena); let mut field_vals = Vec::with_capacity_in(num_fields, env.arena); - for (field_symbol, tag_field_layout) in - arguments.iter().zip(fields[*tag_id as usize].iter()) - { - // note field_layout is the layout of the argument. - // tag_field_layout is the layout that the tag will store - // these are different for recursive tag unions - let (val, field_layout) = load_symbol_and_layout(env, scope, field_symbol); - let field_size = tag_field_layout.stack_size(ptr_size); + let tag_field_layouts = fields[*tag_id as usize]; + for (field_symbol, tag_field_layout) in arguments.iter().zip(tag_field_layouts.iter()) { + let val = load_symbol(env, scope, field_symbol); // Zero-sized fields have no runtime representation. // The layout of the struct expects them to be dropped! - if field_size != 0 { + if !tag_field_layout.is_dropped_because_empty() { let field_type = basic_type_from_layout(env.arena, env.context, tag_field_layout, ptr_size); field_types.push(field_type); if let Layout::RecursivePointer = tag_field_layout { - let ptr = allocate_with_refcount(env, field_layout, val).into(); - let ptr = cast_basic_basic( - builder, - ptr, - ctx.i64_type().ptr_type(AddressSpace::Generic).into(), + let ptr = allocate_with_refcount(env, &tag_layout, val).into(); + + builder.build_store(ptr, val); + + let as_i64_ptr = cast_basic_basic( + env.builder, + ptr.into(), + env.context + .i64_type() + .ptr_type(AddressSpace::Generic) + .into(), ); - field_vals.push(ptr); + + field_vals.push(as_i64_ptr); } else { field_vals.push(val); } @@ -1010,7 +1012,7 @@ pub fn allocate_with_refcount<'a, 'ctx, 'env>( // We must return a pointer to the first element: let ptr_bytes = env.ptr_bytes; let int_type = ptr_int(ctx, ptr_bytes); - let ptr_as_int = builder.build_ptr_to_int(ptr, int_type, "list_cast_ptr"); + let ptr_as_int = builder.build_ptr_to_int(ptr, int_type, "allocate_refcount_pti"); let incremented = builder.build_int_add( ptr_as_int, ctx.i64_type().const_int(offset, false), @@ -1018,7 +1020,7 @@ pub fn allocate_with_refcount<'a, 'ctx, 'env>( ); let ptr_type = get_ptr_type(&value_type, AddressSpace::Generic); - let list_element_ptr = builder.build_int_to_ptr(incremented, ptr_type, "list_cast_ptr"); + let list_element_ptr = builder.build_int_to_ptr(incremented, ptr_type, "allocate_refcount_itp"); // subtract ptr_size, to access the refcount let refcount_ptr = builder.build_int_sub( diff --git a/compiler/gen/src/llvm/refcounting.rs b/compiler/gen/src/llvm/refcounting.rs index fa5c877303..9ae829bc0c 100644 --- a/compiler/gen/src/llvm/refcounting.rs +++ b/compiler/gen/src/llvm/refcounting.rs @@ -89,46 +89,7 @@ pub fn decrement_refcount_layout<'a, 'ctx, 'env>( RecursivePointer => todo!("TODO implement decrement layout of recursive tag union"), Union(tags) => { - debug_assert!(!tags.is_empty()); - let wrapper_struct = value.into_struct_value(); - - // read the tag_id - let tag_id = env - .builder - .build_extract_value(wrapper_struct, 0, "read_tag_id") - .unwrap() - .into_int_value(); - - // next, make a jump table for all possible values of the tag_id - let mut cases = Vec::with_capacity_in(tags.len(), env.arena); - - let merge_block = env.context.append_basic_block(parent, "decrement_merge"); - - for (tag_id, field_layouts) in tags.iter().enumerate() { - let block = env.context.append_basic_block(parent, "tag_id_decrement"); - env.builder.position_at_end(block); - - for (i, field_layout) in field_layouts.iter().enumerate() { - if field_layout.contains_refcounted() { - let field_ptr = env - .builder - .build_extract_value(wrapper_struct, i as u32, "decrement_struct_field") - .unwrap(); - - decrement_refcount_layout(env, parent, layout_ids, field_ptr, field_layout) - } - } - - env.builder.build_unconditional_branch(merge_block); - - cases.push((env.context.i8_type().const_int(tag_id as u64, false), block)); - } - - let (_, default_block) = cases.pop().unwrap(); - - env.builder.build_switch(tag_id, default_block, &cases); - - env.builder.position_at_end(merge_block); + build_dec_union(env, layout_ids, tags, value); } RecursiveUnion(tags) => { @@ -906,14 +867,20 @@ pub fn build_dec_union_help<'a, 'ctx, 'env>( let wrapper_struct = arg_val.into_struct_value(); - // let tag_id_u8 = cast_basic_basic(env.builder, tag_id.into(), env.context.i8_type().into()); - // next, make a jump table for all possible values of the tag_id let mut cases = Vec::with_capacity_in(tags.len(), env.arena); let merge_block = env.context.append_basic_block(parent, "decrement_merge"); for (tag_id, field_layouts) in tags.iter().enumerate() { + // if none of the fields are or contain anything refcounted, just move on + if !field_layouts + .iter() + .any(|x| x.is_refcounted() || x.contains_refcounted()) + { + continue; + } + let block = env.context.append_basic_block(parent, "tag_id_decrement"); env.builder.position_at_end(block); @@ -981,8 +948,6 @@ pub fn build_dec_union_help<'a, 'ctx, 'env>( cases.reverse(); - let (_, default_block) = cases.pop().unwrap(); - env.builder.position_at_end(before_block); // read the tag_id @@ -1002,7 +967,7 @@ pub fn build_dec_union_help<'a, 'ctx, 'env>( // switch on it env.builder - .build_switch(current_tag_id, default_block, &cases); + .build_switch(current_tag_id, merge_block, &cases); env.builder.position_at_end(merge_block); @@ -1109,10 +1074,18 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>( // next, make a jump table for all possible values of the tag_id let mut cases = Vec::with_capacity_in(tags.len(), env.arena); - let merge_block = env.context.append_basic_block(parent, "decrement_merge"); + let merge_block = env.context.append_basic_block(parent, "increment_merge"); for (tag_id, field_layouts) in tags.iter().enumerate() { - let block = env.context.append_basic_block(parent, "tag_id_decrement"); + // if none of the fields are or contain anything refcounted, just move on + if !field_layouts + .iter() + .any(|x| x.is_refcounted() || x.contains_refcounted()) + { + continue; + } + + let block = env.context.append_basic_block(parent, "tag_id_increment"); env.builder.position_at_end(block); let wrapper_type = basic_type_from_layout( @@ -1127,18 +1100,19 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>( for (i, field_layout) in field_layouts.iter().enumerate() { if let Layout::RecursivePointer = field_layout { - // a *i64 pointer to the recursive data - // we need to cast this pointer to the appropriate type - let field_ptr = env + // this field has type `*i64`, but is really a pointer to the data we want + let ptr_as_i64_ptr = env .builder - .build_extract_value(wrapper_struct, i as u32, "decrement_struct_field") + .build_extract_value(wrapper_struct, i as u32, "increment_struct_field") .unwrap(); - // recursively increment + debug_assert!(ptr_as_i64_ptr.is_pointer_value()); + + // therefore we must cast it to our desired type let union_type = block_of_memory(env.context, &layout, env.ptr_bytes); let recursive_field_ptr = cast_basic_basic( env.builder, - field_ptr, + ptr_as_i64_ptr, union_type.ptr_type(AddressSpace::Generic).into(), ) .into_pointer_value(); @@ -1155,9 +1129,9 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>( // Because it's an internal-only function, use the fast calling convention. call.set_call_convention(FAST_CALL_CONV); - // TODO do this increment before the recursive call? + // TODO do this decrement before the recursive call? // Then the recursive call is potentially TCE'd - increment_refcount_ptr(env, &layout, field_ptr.into_pointer_value()); + increment_refcount_ptr(env, &layout, recursive_field_ptr); } else if field_layout.contains_refcounted() { let field_ptr = env .builder @@ -1173,12 +1147,10 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>( cases.push((env.context.i8_type().const_int(tag_id as u64, false), block)); } - let (_, default_block) = cases.pop().unwrap(); - env.builder.position_at_end(before_block); env.builder - .build_switch(tag_id_u8.into_int_value(), default_block, &cases); + .build_switch(tag_id_u8.into_int_value(), merge_block, &cases); env.builder.position_at_end(merge_block); diff --git a/compiler/gen/tests/gen_primitives.rs b/compiler/gen/tests/gen_primitives.rs index 0a693b818d..58bec5323e 100644 --- a/compiler/gen/tests/gen_primitives.rs +++ b/compiler/gen/tests/gen_primitives.rs @@ -1278,7 +1278,6 @@ mod gen_primitives { } #[test] - #[ignore] fn rbtree_balance() { assert_non_opt_evals_to!( indoc!( @@ -1289,18 +1288,38 @@ mod gen_primitives { Dict k v : [ Node NodeColor k v (Dict k v) (Dict k v), Empty ] - Key k : Num k - balance : NodeColor, k, v, Dict k v, Dict k v -> Dict k v balance = \color, key, value, left, right -> when right is - Node Red lK lV (Node Red llK llV llLeft llRight) lRight -> Empty - Empty -> Empty + Node Red rK rV rLeft rRight -> + when left is + Node Red lK lV lLeft lRight -> + Node + Red + key + value + (Node Black lK lV lLeft lRight) + (Node Black rK rV rLeft rRight) + _ -> + Node color rK rV (Node Red key value left rLeft) rRight - main : Dict Int {} + _ -> + when left is + Node Red lK lV (Node Red llK llV llLeft llRight) lRight -> + Node + Red + lK + lV + (Node Black llK llV llLeft llRight) + (Node Black key value lRight right) + + _ -> + Node color key value left right + + main : Dict Int Int main = - balance Red 0 {} Empty Empty + balance Red 0 0 Empty Empty "# ), 1, @@ -1309,7 +1328,33 @@ mod gen_primitives { } #[test] - #[ignore] + fn linked_list_guarded_double_pattern_match() { + // the important part here is that the first case (with the nested Cons) does not match + assert_non_opt_evals_to!( + indoc!( + r#" + app Test provides [ main ] imports [] + + ConsList a : [ Cons a (ConsList a), Nil ] + + balance : ConsList Int -> Int + balance = \right -> + when right is + Cons 1 (Cons 1 _) -> 3 + _ -> 3 + + main : Int + main = + when balance Nil is + _ -> 3 + "# + ), + 3, + i64 + ); + } + + #[test] fn linked_list_double_pattern_match() { assert_non_opt_evals_to!( indoc!( diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index e52641762b..edbae1195b 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -4398,6 +4398,7 @@ fn store_pattern<'a>( field_layouts: arg_layouts.clone().into_bump_slice(), structure: outer_symbol, }; + match argument { Identifier(symbol) => { // store immediately in the given symbol diff --git a/compiler/mono/src/layout.rs b/compiler/mono/src/layout.rs index ea55789a6e..50b572c76c 100644 --- a/compiler/mono/src/layout.rs +++ b/compiler/mono/src/layout.rs @@ -435,6 +435,7 @@ impl<'a> Layout<'a> { match self { Layout::Builtin(Builtin::List(_, _)) => true, Layout::RecursiveUnion(_) => true, + Layout::RecursivePointer => true, _ => false, } }