get RBTree.balance to compile

This commit is contained in:
Folkert 2020-11-14 02:49:28 +01:00
parent ec3868ed7e
commit 154b5cc29f
5 changed files with 104 additions and 83 deletions

View file

@ -752,31 +752,33 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
let mut field_types = Vec::with_capacity_in(num_fields, env.arena); let mut field_types = Vec::with_capacity_in(num_fields, env.arena);
let mut field_vals = Vec::with_capacity_in(num_fields, env.arena); let mut field_vals = Vec::with_capacity_in(num_fields, env.arena);
for (field_symbol, tag_field_layout) in let tag_field_layouts = fields[*tag_id as usize];
arguments.iter().zip(fields[*tag_id as usize].iter()) for (field_symbol, tag_field_layout) in arguments.iter().zip(tag_field_layouts.iter()) {
{ let val = load_symbol(env, scope, field_symbol);
// note field_layout is the layout of the argument.
// tag_field_layout is the layout that the tag will store
// these are different for recursive tag unions
let (val, field_layout) = load_symbol_and_layout(env, scope, field_symbol);
let field_size = tag_field_layout.stack_size(ptr_size);
// Zero-sized fields have no runtime representation. // Zero-sized fields have no runtime representation.
// The layout of the struct expects them to be dropped! // The layout of the struct expects them to be dropped!
if field_size != 0 { if !tag_field_layout.is_dropped_because_empty() {
let field_type = let field_type =
basic_type_from_layout(env.arena, env.context, tag_field_layout, ptr_size); basic_type_from_layout(env.arena, env.context, tag_field_layout, ptr_size);
field_types.push(field_type); field_types.push(field_type);
if let Layout::RecursivePointer = tag_field_layout { if let Layout::RecursivePointer = tag_field_layout {
let ptr = allocate_with_refcount(env, field_layout, val).into(); let ptr = allocate_with_refcount(env, &tag_layout, val).into();
let ptr = cast_basic_basic(
builder, builder.build_store(ptr, val);
ptr,
ctx.i64_type().ptr_type(AddressSpace::Generic).into(), let as_i64_ptr = cast_basic_basic(
env.builder,
ptr.into(),
env.context
.i64_type()
.ptr_type(AddressSpace::Generic)
.into(),
); );
field_vals.push(ptr);
field_vals.push(as_i64_ptr);
} else { } else {
field_vals.push(val); field_vals.push(val);
} }
@ -1010,7 +1012,7 @@ pub fn allocate_with_refcount<'a, 'ctx, 'env>(
// We must return a pointer to the first element: // We must return a pointer to the first element:
let ptr_bytes = env.ptr_bytes; let ptr_bytes = env.ptr_bytes;
let int_type = ptr_int(ctx, ptr_bytes); let int_type = ptr_int(ctx, ptr_bytes);
let ptr_as_int = builder.build_ptr_to_int(ptr, int_type, "list_cast_ptr"); let ptr_as_int = builder.build_ptr_to_int(ptr, int_type, "allocate_refcount_pti");
let incremented = builder.build_int_add( let incremented = builder.build_int_add(
ptr_as_int, ptr_as_int,
ctx.i64_type().const_int(offset, false), ctx.i64_type().const_int(offset, false),
@ -1018,7 +1020,7 @@ pub fn allocate_with_refcount<'a, 'ctx, 'env>(
); );
let ptr_type = get_ptr_type(&value_type, AddressSpace::Generic); let ptr_type = get_ptr_type(&value_type, AddressSpace::Generic);
let list_element_ptr = builder.build_int_to_ptr(incremented, ptr_type, "list_cast_ptr"); let list_element_ptr = builder.build_int_to_ptr(incremented, ptr_type, "allocate_refcount_itp");
// subtract ptr_size, to access the refcount // subtract ptr_size, to access the refcount
let refcount_ptr = builder.build_int_sub( let refcount_ptr = builder.build_int_sub(

View file

@ -89,46 +89,7 @@ pub fn decrement_refcount_layout<'a, 'ctx, 'env>(
RecursivePointer => todo!("TODO implement decrement layout of recursive tag union"), RecursivePointer => todo!("TODO implement decrement layout of recursive tag union"),
Union(tags) => { Union(tags) => {
debug_assert!(!tags.is_empty()); build_dec_union(env, layout_ids, tags, value);
let wrapper_struct = value.into_struct_value();
// read the tag_id
let tag_id = env
.builder
.build_extract_value(wrapper_struct, 0, "read_tag_id")
.unwrap()
.into_int_value();
// next, make a jump table for all possible values of the tag_id
let mut cases = Vec::with_capacity_in(tags.len(), env.arena);
let merge_block = env.context.append_basic_block(parent, "decrement_merge");
for (tag_id, field_layouts) in tags.iter().enumerate() {
let block = env.context.append_basic_block(parent, "tag_id_decrement");
env.builder.position_at_end(block);
for (i, field_layout) in field_layouts.iter().enumerate() {
if field_layout.contains_refcounted() {
let field_ptr = env
.builder
.build_extract_value(wrapper_struct, i as u32, "decrement_struct_field")
.unwrap();
decrement_refcount_layout(env, parent, layout_ids, field_ptr, field_layout)
}
}
env.builder.build_unconditional_branch(merge_block);
cases.push((env.context.i8_type().const_int(tag_id as u64, false), block));
}
let (_, default_block) = cases.pop().unwrap();
env.builder.build_switch(tag_id, default_block, &cases);
env.builder.position_at_end(merge_block);
} }
RecursiveUnion(tags) => { RecursiveUnion(tags) => {
@ -906,14 +867,20 @@ pub fn build_dec_union_help<'a, 'ctx, 'env>(
let wrapper_struct = arg_val.into_struct_value(); let wrapper_struct = arg_val.into_struct_value();
// let tag_id_u8 = cast_basic_basic(env.builder, tag_id.into(), env.context.i8_type().into());
// next, make a jump table for all possible values of the tag_id // next, make a jump table for all possible values of the tag_id
let mut cases = Vec::with_capacity_in(tags.len(), env.arena); let mut cases = Vec::with_capacity_in(tags.len(), env.arena);
let merge_block = env.context.append_basic_block(parent, "decrement_merge"); let merge_block = env.context.append_basic_block(parent, "decrement_merge");
for (tag_id, field_layouts) in tags.iter().enumerate() { for (tag_id, field_layouts) in tags.iter().enumerate() {
// if none of the fields are or contain anything refcounted, just move on
if !field_layouts
.iter()
.any(|x| x.is_refcounted() || x.contains_refcounted())
{
continue;
}
let block = env.context.append_basic_block(parent, "tag_id_decrement"); let block = env.context.append_basic_block(parent, "tag_id_decrement");
env.builder.position_at_end(block); env.builder.position_at_end(block);
@ -981,8 +948,6 @@ pub fn build_dec_union_help<'a, 'ctx, 'env>(
cases.reverse(); cases.reverse();
let (_, default_block) = cases.pop().unwrap();
env.builder.position_at_end(before_block); env.builder.position_at_end(before_block);
// read the tag_id // read the tag_id
@ -1002,7 +967,7 @@ pub fn build_dec_union_help<'a, 'ctx, 'env>(
// switch on it // switch on it
env.builder env.builder
.build_switch(current_tag_id, default_block, &cases); .build_switch(current_tag_id, merge_block, &cases);
env.builder.position_at_end(merge_block); env.builder.position_at_end(merge_block);
@ -1109,10 +1074,18 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>(
// next, make a jump table for all possible values of the tag_id // next, make a jump table for all possible values of the tag_id
let mut cases = Vec::with_capacity_in(tags.len(), env.arena); let mut cases = Vec::with_capacity_in(tags.len(), env.arena);
let merge_block = env.context.append_basic_block(parent, "decrement_merge"); let merge_block = env.context.append_basic_block(parent, "increment_merge");
for (tag_id, field_layouts) in tags.iter().enumerate() { for (tag_id, field_layouts) in tags.iter().enumerate() {
let block = env.context.append_basic_block(parent, "tag_id_decrement"); // if none of the fields are or contain anything refcounted, just move on
if !field_layouts
.iter()
.any(|x| x.is_refcounted() || x.contains_refcounted())
{
continue;
}
let block = env.context.append_basic_block(parent, "tag_id_increment");
env.builder.position_at_end(block); env.builder.position_at_end(block);
let wrapper_type = basic_type_from_layout( let wrapper_type = basic_type_from_layout(
@ -1127,18 +1100,19 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>(
for (i, field_layout) in field_layouts.iter().enumerate() { for (i, field_layout) in field_layouts.iter().enumerate() {
if let Layout::RecursivePointer = field_layout { if let Layout::RecursivePointer = field_layout {
// a *i64 pointer to the recursive data // this field has type `*i64`, but is really a pointer to the data we want
// we need to cast this pointer to the appropriate type let ptr_as_i64_ptr = env
let field_ptr = env
.builder .builder
.build_extract_value(wrapper_struct, i as u32, "decrement_struct_field") .build_extract_value(wrapper_struct, i as u32, "increment_struct_field")
.unwrap(); .unwrap();
// recursively increment debug_assert!(ptr_as_i64_ptr.is_pointer_value());
// therefore we must cast it to our desired type
let union_type = block_of_memory(env.context, &layout, env.ptr_bytes); let union_type = block_of_memory(env.context, &layout, env.ptr_bytes);
let recursive_field_ptr = cast_basic_basic( let recursive_field_ptr = cast_basic_basic(
env.builder, env.builder,
field_ptr, ptr_as_i64_ptr,
union_type.ptr_type(AddressSpace::Generic).into(), union_type.ptr_type(AddressSpace::Generic).into(),
) )
.into_pointer_value(); .into_pointer_value();
@ -1155,9 +1129,9 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>(
// Because it's an internal-only function, use the fast calling convention. // Because it's an internal-only function, use the fast calling convention.
call.set_call_convention(FAST_CALL_CONV); call.set_call_convention(FAST_CALL_CONV);
// TODO do this increment before the recursive call? // TODO do this decrement before the recursive call?
// Then the recursive call is potentially TCE'd // Then the recursive call is potentially TCE'd
increment_refcount_ptr(env, &layout, field_ptr.into_pointer_value()); increment_refcount_ptr(env, &layout, recursive_field_ptr);
} else if field_layout.contains_refcounted() { } else if field_layout.contains_refcounted() {
let field_ptr = env let field_ptr = env
.builder .builder
@ -1173,12 +1147,10 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>(
cases.push((env.context.i8_type().const_int(tag_id as u64, false), block)); cases.push((env.context.i8_type().const_int(tag_id as u64, false), block));
} }
let (_, default_block) = cases.pop().unwrap();
env.builder.position_at_end(before_block); env.builder.position_at_end(before_block);
env.builder env.builder
.build_switch(tag_id_u8.into_int_value(), default_block, &cases); .build_switch(tag_id_u8.into_int_value(), merge_block, &cases);
env.builder.position_at_end(merge_block); env.builder.position_at_end(merge_block);

View file

@ -1278,7 +1278,6 @@ mod gen_primitives {
} }
#[test] #[test]
#[ignore]
fn rbtree_balance() { fn rbtree_balance() {
assert_non_opt_evals_to!( assert_non_opt_evals_to!(
indoc!( indoc!(
@ -1289,18 +1288,38 @@ mod gen_primitives {
Dict k v : [ Node NodeColor k v (Dict k v) (Dict k v), Empty ] Dict k v : [ Node NodeColor k v (Dict k v) (Dict k v), Empty ]
Key k : Num k
balance : NodeColor, k, v, Dict k v, Dict k v -> Dict k v balance : NodeColor, k, v, Dict k v, Dict k v -> Dict k v
balance = \color, key, value, left, right -> balance = \color, key, value, left, right ->
when right is when right is
Node Red lK lV (Node Red llK llV llLeft llRight) lRight -> Empty Node Red rK rV rLeft rRight ->
Empty -> Empty when left is
Node Red lK lV lLeft lRight ->
Node
Red
key
value
(Node Black lK lV lLeft lRight)
(Node Black rK rV rLeft rRight)
_ ->
Node color rK rV (Node Red key value left rLeft) rRight
main : Dict Int {} _ ->
when left is
Node Red lK lV (Node Red llK llV llLeft llRight) lRight ->
Node
Red
lK
lV
(Node Black llK llV llLeft llRight)
(Node Black key value lRight right)
_ ->
Node color key value left right
main : Dict Int Int
main = main =
balance Red 0 {} Empty Empty balance Red 0 0 Empty Empty
"# "#
), ),
1, 1,
@ -1309,7 +1328,33 @@ mod gen_primitives {
} }
#[test] #[test]
#[ignore] fn linked_list_guarded_double_pattern_match() {
// the important part here is that the first case (with the nested Cons) does not match
assert_non_opt_evals_to!(
indoc!(
r#"
app Test provides [ main ] imports []
ConsList a : [ Cons a (ConsList a), Nil ]
balance : ConsList Int -> Int
balance = \right ->
when right is
Cons 1 (Cons 1 _) -> 3
_ -> 3
main : Int
main =
when balance Nil is
_ -> 3
"#
),
3,
i64
);
}
#[test]
fn linked_list_double_pattern_match() { fn linked_list_double_pattern_match() {
assert_non_opt_evals_to!( assert_non_opt_evals_to!(
indoc!( indoc!(

View file

@ -4398,6 +4398,7 @@ fn store_pattern<'a>(
field_layouts: arg_layouts.clone().into_bump_slice(), field_layouts: arg_layouts.clone().into_bump_slice(),
structure: outer_symbol, structure: outer_symbol,
}; };
match argument { match argument {
Identifier(symbol) => { Identifier(symbol) => {
// store immediately in the given symbol // store immediately in the given symbol

View file

@ -435,6 +435,7 @@ impl<'a> Layout<'a> {
match self { match self {
Layout::Builtin(Builtin::List(_, _)) => true, Layout::Builtin(Builtin::List(_, _)) => true,
Layout::RecursiveUnion(_) => true, Layout::RecursiveUnion(_) => true,
Layout::RecursivePointer => true,
_ => false, _ => false,
} }
} }