change over nullable wrapped

This commit is contained in:
Folkert 2021-06-27 01:35:15 +02:00
parent 71857e83d0
commit 8add147dcf
5 changed files with 230 additions and 152 deletions

View file

@ -1043,14 +1043,15 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
// This tricks comes from // This tricks comes from
// https://github.com/raviqqe/ssf/blob/bc32aae68940d5bddf5984128e85af75ca4f4686/ssf-llvm/src/expression_compiler.rs#L116 // https://github.com/raviqqe/ssf/blob/bc32aae68940d5bddf5984128e85af75ca4f4686/ssf-llvm/src/expression_compiler.rs#L116
let internal_type = block_of_memory(env.context, layout, env.ptr_bytes); let internal_type = block_of_memory_slices(env.context, fields, env.ptr_bytes);
let data = cast_tag_to_block_of_memory(builder, struct_val, internal_type); let data = cast_tag_to_block_of_memory(builder, struct_val, internal_type);
let tag_id_type = env.context.i64_type();
let wrapper_type = env let wrapper_type = env
.context .context
.struct_type(&[data.get_type(), env.context.i64_type().into()], false); .struct_type(&[data.get_type(), tag_id_type.into()], false);
let tag_id_intval = env.context.i64_type().const_int(*tag_id as u64, false); let tag_id_intval = tag_id_type.const_int(*tag_id as u64, false);
let field_vals = [ let field_vals = [
(TAG_ID_INDEX as usize, tag_id_intval.into()), (TAG_ID_INDEX as usize, tag_id_intval.into()),
@ -1061,7 +1062,7 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
} }
Tag { Tag {
arguments, arguments,
tag_layout: UnionLayout::Recursive(fields), tag_layout: UnionLayout::Recursive(tags),
union_size, union_size,
tag_id, tag_id,
.. ..
@ -1076,7 +1077,7 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
let mut field_types = Vec::with_capacity_in(num_fields, env.arena); let mut field_types = Vec::with_capacity_in(num_fields, env.arena);
let mut field_vals = Vec::with_capacity_in(num_fields, env.arena); let mut field_vals = Vec::with_capacity_in(num_fields, env.arena);
let tag_field_layouts = &fields[*tag_id as usize]; let tag_field_layouts = &tags[*tag_id as usize];
for (field_symbol, tag_field_layout) in arguments.iter().zip(tag_field_layouts.iter()) { for (field_symbol, tag_field_layout) in arguments.iter().zip(tag_field_layouts.iter()) {
let (val, _val_layout) = load_symbol_and_layout(scope, field_symbol); let (val, _val_layout) = load_symbol_and_layout(scope, field_symbol);
@ -1105,28 +1106,40 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
} }
// Create the struct_type // Create the struct_type
let data_ptr = reserve_with_refcount_union_as_block_of_memory(env, fields); let raw_data_ptr = reserve_with_refcount_union_as_block_of_memory2(env, tags);
let struct_type = ctx.struct_type(field_types.into_bump_slice(), false); let tag_id_ptr = builder
.build_struct_gep(raw_data_ptr, TAG_ID_INDEX, "tag_id_index")
.unwrap();
let tag_id_type = env.context.i64_type();
env.builder
.build_store(tag_id_ptr, tag_id_type.const_int(*tag_id as u64, false));
let opaque_struct_ptr = builder
.build_struct_gep(raw_data_ptr, TAG_DATA_INDEX, "tag_data_index")
.unwrap();
let struct_type = env.context.struct_type(&field_types, false);
let struct_ptr = env let struct_ptr = env
.builder .builder
.build_bitcast( .build_bitcast(
data_ptr, opaque_struct_ptr,
struct_type.ptr_type(AddressSpace::Generic), struct_type.ptr_type(AddressSpace::Generic),
"block_of_memory_to_tag", "struct_ptr",
) )
.into_pointer_value(); .into_pointer_value();
// Insert field exprs into struct_val // Insert field exprs into struct_val
for (index, field_val) in field_vals.into_iter().enumerate() { for (index, field_val) in field_vals.into_iter().enumerate() {
let field_ptr = builder let field_ptr = builder
.build_struct_gep(struct_ptr, index as u32, "struct_gep") .build_struct_gep(struct_ptr, index as u32, "field_struct_gep")
.unwrap(); .unwrap();
builder.build_store(field_ptr, field_val); builder.build_store(field_ptr, field_val);
} }
data_ptr.into() raw_data_ptr.into()
} }
Tag { Tag {
@ -1207,17 +1220,22 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
tag_layout: tag_layout:
UnionLayout::NullableWrapped { UnionLayout::NullableWrapped {
nullable_id, nullable_id,
other_tags: fields, other_tags: tags,
}, },
union_size, union_size,
tag_id, tag_id,
.. ..
} => { } => {
let tag_struct_type = block_of_memory_slices(env.context, fields, env.ptr_bytes);
if *tag_id == *nullable_id as u8 { if *tag_id == *nullable_id as u8 {
let output_type = tag_struct_type.ptr_type(AddressSpace::Generic); let layout = Layout::Union(UnionLayout::NullableWrapped {
nullable_id: *nullable_id,
other_tags: tags,
});
return output_type.const_null().into(); return basic_type_from_layout(env, &layout)
.into_pointer_type()
.const_null()
.into();
} }
debug_assert!(*union_size > 1); debug_assert!(*union_size > 1);
@ -1233,9 +1251,9 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
let tag_field_layouts = { let tag_field_layouts = {
use std::cmp::Ordering::*; use std::cmp::Ordering::*;
match tag_id.cmp(&(*nullable_id as u8)) { match tag_id.cmp(&(*nullable_id as u8)) {
Equal => &[] as &[_], Equal => unreachable!("early return above"),
Less => &fields[*tag_id as usize], Less => &tags[*tag_id as usize],
Greater => &fields[*tag_id as usize - 1], Greater => &tags[*tag_id as usize - 1],
} }
}; };
@ -1270,28 +1288,40 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
} }
// Create the struct_type // Create the struct_type
let data_ptr = reserve_with_refcount_union_as_block_of_memory(env, fields); let raw_data_ptr = reserve_with_refcount_union_as_block_of_memory2(env, tags);
let struct_type = ctx.struct_type(field_types.into_bump_slice(), false); let tag_id_ptr = builder
.build_struct_gep(raw_data_ptr, TAG_ID_INDEX, "tag_id_index")
.unwrap();
let tag_id_type = env.context.i64_type();
env.builder
.build_store(tag_id_ptr, tag_id_type.const_int(*tag_id as u64, false));
let opaque_struct_ptr = builder
.build_struct_gep(raw_data_ptr, TAG_DATA_INDEX, "tag_data_index")
.unwrap();
let struct_type = env.context.struct_type(&field_types, false);
let struct_ptr = env let struct_ptr = env
.builder .builder
.build_bitcast( .build_bitcast(
data_ptr, opaque_struct_ptr,
struct_type.ptr_type(AddressSpace::Generic), struct_type.ptr_type(AddressSpace::Generic),
"block_of_memory_to_tag", "struct_ptr",
) )
.into_pointer_value(); .into_pointer_value();
// Insert field exprs into struct_val // Insert field exprs into struct_val
for (index, field_val) in field_vals.into_iter().enumerate() { for (index, field_val) in field_vals.into_iter().enumerate() {
let field_ptr = builder let field_ptr = builder
.build_struct_gep(struct_ptr, index as u32, "struct_gep") .build_struct_gep(struct_ptr, index as u32, "field_struct_gep")
.unwrap(); .unwrap();
builder.build_store(field_ptr, field_val); builder.build_store(field_ptr, field_val);
} }
data_ptr.into() raw_data_ptr.into()
} }
Tag { Tag {
@ -1499,17 +1529,12 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
debug_assert!(argument.is_pointer_value()); debug_assert!(argument.is_pointer_value());
let field_layouts = tag_layouts[*tag_id as usize]; let field_layouts = tag_layouts[*tag_id as usize];
let struct_layout = Layout::Struct(field_layouts);
let struct_type = basic_type_from_layout(env, &struct_layout); lookup_at_index_ptr2(
lookup_at_index_ptr(
env, env,
field_layouts, field_layouts,
*index as usize, *index as usize,
argument.into_pointer_value(), argument.into_pointer_value(),
struct_type.into_struct_type(),
&struct_layout,
) )
} }
UnionLayout::NonNullableUnwrapped(field_layouts) => { UnionLayout::NonNullableUnwrapped(field_layouts) => {
@ -1540,17 +1565,12 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
}; };
let field_layouts = other_tags[tag_index as usize]; let field_layouts = other_tags[tag_index as usize];
let struct_layout = Layout::Struct(field_layouts);
let struct_type = basic_type_from_layout(env, &struct_layout); lookup_at_index_ptr2(
lookup_at_index_ptr(
env, env,
field_layouts, field_layouts,
*index as usize, *index as usize,
argument.into_pointer_value(), argument.into_pointer_value(),
struct_type.into_struct_type(),
&struct_layout,
) )
} }
UnionLayout::NullableUnwrapped { UnionLayout::NullableUnwrapped {
@ -1606,13 +1626,7 @@ pub fn get_tag_id<'a, 'ctx, 'env>(
.unwrap() .unwrap()
} }
UnionLayout::Recursive(_) => { UnionLayout::Recursive(_) => {
let pointer = argument.into_pointer_value(); extract_tag_discriminant_ptr2(env, argument.into_pointer_value()).into()
let tag_id_pointer = builder.build_bitcast(
pointer,
env.context.i64_type().ptr_type(AddressSpace::Generic),
"tag_id_pointer",
);
builder.build_load(tag_id_pointer.into_pointer_value(), "load_tag_id")
} }
UnionLayout::NonNullableUnwrapped(_) => env.context.i64_type().const_zero().into(), UnionLayout::NonNullableUnwrapped(_) => env.context.i64_type().const_zero().into(),
UnionLayout::NullableWrapped { nullable_id, .. } => { UnionLayout::NullableWrapped { nullable_id, .. } => {
@ -1638,7 +1652,7 @@ pub fn get_tag_id<'a, 'ctx, 'env>(
{ {
env.builder.position_at_end(else_block); env.builder.position_at_end(else_block);
let tag_id = extract_tag_discriminant_ptr(env, argument_ptr); let tag_id = extract_tag_discriminant_ptr2(env, argument_ptr);
env.builder.build_store(result, tag_id); env.builder.build_store(result, tag_id);
env.builder.build_unconditional_branch(cont_block); env.builder.build_unconditional_branch(cont_block);
} }
@ -1701,6 +1715,62 @@ fn lookup_at_index_ptr<'a, 'ctx, 'env>(
} }
} }
fn lookup_at_index_ptr2<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
field_layouts: &[Layout<'_>],
index: usize,
value: PointerValue<'ctx>,
) -> BasicValueEnum<'ctx> {
let builder = env.builder;
let struct_layout = Layout::Struct(field_layouts);
let struct_type = basic_type_from_layout(env, &struct_layout);
let tag_id_type = env.context.i64_type();
let wrapper_type = env
.context
.struct_type(&[struct_type, tag_id_type.into()], false);
let ptr = env
.builder
.build_bitcast(
value,
wrapper_type.ptr_type(AddressSpace::Generic),
"cast_lookup_at_index_ptr",
)
.into_pointer_value();
let data_ptr = builder
.build_struct_gep(ptr, TAG_DATA_INDEX, "at_index_struct_gep")
.unwrap();
let elem_ptr = builder
.build_struct_gep(data_ptr, index as u32, "at_index_struct_gep")
.unwrap();
let result = builder.build_load(elem_ptr, "load_at_index_ptr");
if let Some(Layout::RecursivePointer) = field_layouts.get(index as usize) {
// a recursive field is stored as a `i64*`, to use it we must cast it to
// a pointer to the block of memory representation
let struct_type = block_of_memory_slices(env.context, &[field_layouts], env.ptr_bytes);
let tag_id_type = env.context.i64_type();
let opaque_wrapper_type = env
.context
.struct_type(&[struct_type, tag_id_type.into()], false);
builder.build_bitcast(
result,
opaque_wrapper_type.ptr_type(AddressSpace::Generic),
"cast_rec_pointer_lookup_at_index_ptr",
)
} else {
result
}
}
pub fn reserve_with_refcount<'a, 'ctx, 'env>( pub fn reserve_with_refcount<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>, env: &Env<'a, 'ctx, 'env>,
layout: &Layout<'a>, layout: &Layout<'a>,
@ -1735,6 +1805,35 @@ fn reserve_with_refcount_union_as_block_of_memory<'a, 'ctx, 'env>(
reserve_with_refcount_help(env, basic_type, stack_size, alignment_bytes) reserve_with_refcount_help(env, basic_type, stack_size, alignment_bytes)
} }
fn reserve_with_refcount_union_as_block_of_memory2<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
fields: &[&[Layout<'a>]],
) -> PointerValue<'ctx> {
let block_type = block_of_memory_slices(env.context, fields, env.ptr_bytes);
let tag_id_type = env.context.i64_type();
let basic_type = env
.context
.struct_type(&[block_type, tag_id_type.into()], false);
let stack_size = fields
.iter()
.map(|tag| tag.iter().map(|l| l.stack_size(env.ptr_bytes)).sum())
.max()
.unwrap_or(0)
// add tag id
+ env.ptr_bytes;
let alignment_bytes = fields
.iter()
.map(|tag| tag.iter().map(|l| l.alignment_bytes(env.ptr_bytes)))
.flatten()
.max()
.unwrap_or(0);
reserve_with_refcount_help(env, basic_type, stack_size, alignment_bytes)
}
fn reserve_with_refcount_help<'a, 'ctx, 'env>( fn reserve_with_refcount_help<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>, env: &Env<'a, 'ctx, 'env>,
basic_type: impl BasicType<'ctx>, basic_type: impl BasicType<'ctx>,
@ -2495,95 +2594,21 @@ pub fn extract_tag_discriminant<'a, 'ctx, 'env>(
union_layout: UnionLayout<'a>, union_layout: UnionLayout<'a>,
cond_value: BasicValueEnum<'ctx>, cond_value: BasicValueEnum<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let builder = env.builder; get_tag_id(env, parent, &union_layout, cond_value).into_int_value()
match union_layout {
UnionLayout::NonRecursive(_) => {
let pointer = builder.build_alloca(cond_value.get_type(), "get_type");
builder.build_store(pointer, cond_value);
let tag_id_pointer = builder.build_bitcast(
pointer,
env.context.i64_type().ptr_type(AddressSpace::Generic),
"tag_id_pointer",
);
builder
.build_load(tag_id_pointer.into_pointer_value(), "load_tag_id")
.into_int_value()
}
UnionLayout::Recursive(_) => {
let pointer = cond_value.into_pointer_value();
let tag_id_pointer = builder.build_bitcast(
pointer,
env.context.i64_type().ptr_type(AddressSpace::Generic),
"tag_id_pointer",
);
builder
.build_load(tag_id_pointer.into_pointer_value(), "load_tag_id")
.into_int_value()
}
UnionLayout::NonNullableUnwrapped(_) => env.context.i64_type().const_zero(),
UnionLayout::NullableWrapped { nullable_id, .. } => {
let argument_ptr = cond_value.into_pointer_value();
let is_null = env.builder.build_is_null(argument_ptr, "is_null");
let ctx = env.context;
let then_block = ctx.append_basic_block(parent, "then");
let else_block = ctx.append_basic_block(parent, "else");
let cont_block = ctx.append_basic_block(parent, "cont");
let result = builder.build_alloca(ctx.i64_type(), "result");
env.builder
.build_conditional_branch(is_null, then_block, else_block);
{
env.builder.position_at_end(then_block);
let tag_id = ctx.i64_type().const_int(nullable_id as u64, false);
env.builder.build_store(result, tag_id);
env.builder.build_unconditional_branch(cont_block);
}
{
env.builder.position_at_end(else_block);
let tag_id = extract_tag_discriminant_ptr(env, argument_ptr);
env.builder.build_store(result, tag_id);
env.builder.build_unconditional_branch(cont_block);
}
env.builder.position_at_end(cont_block);
env.builder
.build_load(result, "load_result")
.into_int_value()
}
UnionLayout::NullableUnwrapped { nullable_id, .. } => {
let argument_ptr = cond_value.into_pointer_value();
let is_null = env.builder.build_is_null(argument_ptr, "is_null");
let ctx = env.context;
let then_value = ctx.i64_type().const_int(nullable_id as u64, false);
let else_value = ctx.i64_type().const_int(!nullable_id as u64, false);
env.builder
.build_select(is_null, then_value, else_value, "select_tag_id")
.into_int_value()
}
}
} }
fn extract_tag_discriminant_ptr<'a, 'ctx, 'env>( fn extract_tag_discriminant_ptr2<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>, env: &Env<'a, 'ctx, 'env>,
from_value: PointerValue<'ctx>, from_value: PointerValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let tag_id_ptr_type = env.context.i64_type().ptr_type(AddressSpace::Generic); let tag_id_ptr = env
let ptr = env
.builder .builder
.build_bitcast(from_value, tag_id_ptr_type, "extract_tag_discriminant_ptr") .build_struct_gep(from_value, TAG_ID_INDEX, "tag_id_ptr")
.into_pointer_value(); .unwrap();
env.builder.build_load(ptr, "load_tag_id").into_int_value() env.builder
.build_load(tag_id_ptr, "load_tag_id")
.into_int_value()
} }
struct SwitchArgsIr<'a, 'ctx> { struct SwitchArgsIr<'a, 'ctx> {

View file

@ -1,7 +1,9 @@
use crate::debug_info_init; use crate::debug_info_init;
use crate::llvm::bitcode::call_bitcode_fn; use crate::llvm::bitcode::call_bitcode_fn;
use crate::llvm::build::Env; use crate::llvm::build::Env;
use crate::llvm::build::{cast_block_of_memory_to_tag, complex_bitcast, FAST_CALL_CONV}; use crate::llvm::build::{
cast_block_of_memory_to_tag, complex_bitcast, get_tag_id, FAST_CALL_CONV,
};
use crate::llvm::build_str; use crate::llvm::build_str;
use crate::llvm::convert::basic_type_from_layout; use crate::llvm::convert::basic_type_from_layout;
use bumpalo::collections::Vec; use bumpalo::collections::Vec;
@ -406,7 +408,6 @@ fn hash_tag<'a, 'ctx, 'env>(
env.builder.position_at_end(entry_block); env.builder.position_at_end(entry_block);
match union_layout { match union_layout {
NonRecursive(tags) => { NonRecursive(tags) => {
// SAFETY we know that non-recursive tags cannot be NULL
let tag_id = nonrec_tag_id(env, tag.into_struct_value()); let tag_id = nonrec_tag_id(env, tag.into_struct_value());
let mut cases = Vec::with_capacity_in(tags.len(), env.arena); let mut cases = Vec::with_capacity_in(tags.len(), env.arena);
@ -449,8 +450,7 @@ fn hash_tag<'a, 'ctx, 'env>(
env.builder.build_switch(tag_id, default, &cases); env.builder.build_switch(tag_id, default, &cases);
} }
Recursive(tags) => { Recursive(tags) => {
// SAFETY recursive tag unions are not NULL let tag_id = get_tag_id(env, parent, union_layout, tag).into_int_value();
let tag_id = unsafe { rec_tag_id_unsafe(env, tag.into_pointer_value()) };
let mut cases = Vec::with_capacity_in(tags.len(), env.arena); let mut cases = Vec::with_capacity_in(tags.len(), env.arena);

View file

@ -34,12 +34,15 @@ pub fn basic_type_from_layout<'a, 'ctx, 'env>(
Union(variant) => { Union(variant) => {
use UnionLayout::*; use UnionLayout::*;
match variant { match variant {
Recursive(tags) NullableWrapped {
| NullableWrapped {
other_tags: tags, .. other_tags: tags, ..
} => { } => {
let block = block_of_memory_slices(env.context, tags, env.ptr_bytes); let data = block_of_memory_slices(env.context, tags, env.ptr_bytes);
block.ptr_type(AddressSpace::Generic).into()
env.context
.struct_type(&[data, env.context.i64_type().into()], false)
.ptr_type(AddressSpace::Generic)
.into()
} }
NullableUnwrapped { other_fields, .. } => { NullableUnwrapped { other_fields, .. } => {
let block = let block =
@ -50,8 +53,16 @@ pub fn basic_type_from_layout<'a, 'ctx, 'env>(
let block = block_of_memory_slices(env.context, &[fields], env.ptr_bytes); let block = block_of_memory_slices(env.context, &[fields], env.ptr_bytes);
block.ptr_type(AddressSpace::Generic).into() block.ptr_type(AddressSpace::Generic).into()
} }
NonRecursive(_) => { Recursive(tags) => {
let data = block_of_memory(env.context, layout, env.ptr_bytes); let data = block_of_memory_slices(env.context, tags, env.ptr_bytes);
env.context
.struct_type(&[data, env.context.i64_type().into()], false)
.ptr_type(AddressSpace::Generic)
.into()
}
NonRecursive(tags) => {
let data = block_of_memory_slices(env.context, tags, env.ptr_bytes);
env.context env.context
.struct_type(&[data, env.context.i64_type().into()], false) .struct_type(&[data, env.context.i64_type().into()], false)
@ -117,6 +128,32 @@ pub fn block_of_memory_slices<'ctx>(
block_of_memory_help(context, union_size) block_of_memory_help(context, union_size)
} }
pub fn union_data_is_struct<'a, 'ctx, 'env>(
env: &crate::llvm::build::Env<'a, 'ctx, 'env>,
layouts: &[Layout<'_>],
) -> StructType<'ctx> {
let data_type = basic_type_from_record(env, layouts);
union_data_is_struct_type(env.context, data_type.into_struct_type())
}
pub fn union_data_is_struct_type<'ctx>(
context: &'ctx Context,
struct_type: StructType<'ctx>,
) -> StructType<'ctx> {
let tag_id_type = context.i64_type();
context.struct_type(&[struct_type.into(), tag_id_type.into()], false)
}
pub fn union_data_block_of_memory<'ctx>(
context: &'ctx Context,
layouts: &[&[Layout<'_>]],
ptr_bytes: u32,
) -> StructType<'ctx> {
let tag_id_type = context.i64_type();
let data_type = block_of_memory_slices(context, layouts, ptr_bytes);
context.struct_type(&[data_type, tag_id_type.into()], false)
}
pub fn block_of_memory<'ctx>( pub fn block_of_memory<'ctx>(
context: &'ctx Context, context: &'ctx Context,
layout: &Layout<'_>, layout: &Layout<'_>,

View file

@ -4,7 +4,9 @@ use crate::llvm::build::{
LLVM_SADD_WITH_OVERFLOW_I64, TAG_DATA_INDEX, TAG_ID_INDEX, LLVM_SADD_WITH_OVERFLOW_I64, TAG_DATA_INDEX, TAG_ID_INDEX,
}; };
use crate::llvm::build_list::{incrementing_elem_loop, list_len, load_list}; use crate::llvm::build_list::{incrementing_elem_loop, list_len, load_list};
use crate::llvm::convert::{basic_type_from_layout, block_of_memory_slices, ptr_int}; use crate::llvm::convert::{
basic_type_from_layout, block_of_memory_slices, ptr_int, union_data_block_of_memory,
};
use bumpalo::collections::Vec; use bumpalo::collections::Vec;
use inkwell::basic_block::BasicBlock; use inkwell::basic_block::BasicBlock;
use inkwell::context::Context; use inkwell::context::Context;
@ -651,6 +653,7 @@ fn modify_refcount_layout_build_function<'a, 'ctx, 'env>(
layout_ids, layout_ids,
mode, mode,
&WhenRecursive::Loop(*variant), &WhenRecursive::Loop(*variant),
*variant,
tags, tags,
true, true,
); );
@ -666,6 +669,7 @@ fn modify_refcount_layout_build_function<'a, 'ctx, 'env>(
layout_ids, layout_ids,
mode, mode,
&WhenRecursive::Loop(*variant), &WhenRecursive::Loop(*variant),
*variant,
&*env.arena.alloc([other_fields]), &*env.arena.alloc([other_fields]),
true, true,
); );
@ -679,6 +683,7 @@ fn modify_refcount_layout_build_function<'a, 'ctx, 'env>(
layout_ids, layout_ids,
mode, mode,
&WhenRecursive::Loop(*variant), &WhenRecursive::Loop(*variant),
*variant,
&*env.arena.alloc([*fields]), &*env.arena.alloc([*fields]),
true, true,
); );
@ -691,6 +696,7 @@ fn modify_refcount_layout_build_function<'a, 'ctx, 'env>(
layout_ids, layout_ids,
mode, mode,
&WhenRecursive::Loop(*variant), &WhenRecursive::Loop(*variant),
*variant,
tags, tags,
false, false,
); );
@ -1203,10 +1209,11 @@ fn build_rec_union<'a, 'ctx, 'env>(
layout_ids: &mut LayoutIds<'a>, layout_ids: &mut LayoutIds<'a>,
mode: Mode, mode: Mode,
when_recursive: &WhenRecursive<'a>, when_recursive: &WhenRecursive<'a>,
fields: &'a [&'a [Layout<'a>]], union_layout: UnionLayout<'a>,
tags: &'a [&'a [Layout<'a>]],
is_nullable: bool, is_nullable: bool,
) -> FunctionValue<'ctx> { ) -> FunctionValue<'ctx> {
let layout = Layout::Union(UnionLayout::Recursive(fields)); let layout = Layout::Union(UnionLayout::Recursive(tags));
let (_, fn_name) = function_name_from_mode( let (_, fn_name) = function_name_from_mode(
layout_ids, layout_ids,
@ -1223,9 +1230,7 @@ fn build_rec_union<'a, 'ctx, 'env>(
let block = env.builder.get_insert_block().expect("to be in a function"); let block = env.builder.get_insert_block().expect("to be in a function");
let di_location = env.builder.get_current_debug_location().unwrap(); let di_location = env.builder.get_current_debug_location().unwrap();
let basic_type = block_of_memory_slices(env.context, fields, env.ptr_bytes) let basic_type = basic_type_from_layout(env, &Layout::Union(union_layout));
.ptr_type(AddressSpace::Generic)
.into();
let function_value = build_header(env, basic_type, mode, &fn_name); let function_value = build_header(env, basic_type, mode, &fn_name);
build_rec_union_help( build_rec_union_help(
@ -1233,7 +1238,8 @@ fn build_rec_union<'a, 'ctx, 'env>(
layout_ids, layout_ids,
mode, mode,
when_recursive, when_recursive,
fields, union_layout,
tags,
function_value, function_value,
is_nullable, is_nullable,
); );
@ -1249,11 +1255,13 @@ fn build_rec_union<'a, 'ctx, 'env>(
function function
} }
#[allow(clippy::too_many_arguments)]
fn build_rec_union_help<'a, 'ctx, 'env>( fn build_rec_union_help<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>, env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>, layout_ids: &mut LayoutIds<'a>,
mode: Mode, mode: Mode,
when_recursive: &WhenRecursive<'a>, when_recursive: &WhenRecursive<'a>,
union_layout: UnionLayout<'a>,
tags: &'a [&'a [roc_mono::layout::Layout<'a>]], tags: &'a [&'a [roc_mono::layout::Layout<'a>]],
fn_val: FunctionValue<'ctx>, fn_val: FunctionValue<'ctx>,
is_nullable: bool, is_nullable: bool,
@ -1279,8 +1287,6 @@ fn build_rec_union_help<'a, 'ctx, 'env>(
let parent = fn_val; let parent = fn_val;
let layout = Layout::Union(UnionLayout::Recursive(tags));
debug_assert!(arg_val.is_pointer_value()); debug_assert!(arg_val.is_pointer_value());
let value_ptr = arg_val.into_pointer_value(); let value_ptr = arg_val.into_pointer_value();
@ -1309,6 +1315,8 @@ fn build_rec_union_help<'a, 'ctx, 'env>(
env.builder.position_at_end(should_recurse_block); env.builder.position_at_end(should_recurse_block);
let layout = Layout::Union(union_layout);
match mode { match mode {
Mode::Inc => { Mode::Inc => {
// inc is cheap; we never recurse // inc is cheap; we never recurse
@ -1342,7 +1350,7 @@ fn build_rec_union_help<'a, 'ctx, 'env>(
when_recursive, when_recursive,
parent, parent,
fn_val, fn_val,
layout, union_layout,
tags, tags,
value_ptr, value_ptr,
refcount_ptr, refcount_ptr,
@ -1360,7 +1368,7 @@ fn build_rec_union_recursive_decrement<'a, 'ctx, 'env>(
when_recursive: &WhenRecursive<'a>, when_recursive: &WhenRecursive<'a>,
parent: FunctionValue<'ctx>, parent: FunctionValue<'ctx>,
decrement_fn: FunctionValue<'ctx>, decrement_fn: FunctionValue<'ctx>,
layout: Layout<'a>, union_layout: UnionLayout<'a>,
tags: &[&[Layout<'a>]], tags: &[&[Layout<'a>]],
value_ptr: PointerValue<'ctx>, value_ptr: PointerValue<'ctx>,
refcount_ptr: PointerToRefcount<'ctx>, refcount_ptr: PointerToRefcount<'ctx>,
@ -1433,7 +1441,15 @@ fn build_rec_union_recursive_decrement<'a, 'ctx, 'env>(
debug_assert!(ptr_as_i64_ptr.is_pointer_value()); debug_assert!(ptr_as_i64_ptr.is_pointer_value());
// therefore we must cast it to our desired type // therefore we must cast it to our desired type
let union_type = block_of_memory_slices(env.context, tags, env.ptr_bytes);
let union_type = match union_layout {
UnionLayout::Recursive(_) | UnionLayout::NullableWrapped { .. } => {
union_data_block_of_memory(env.context, tags, env.ptr_bytes).into()
}
UnionLayout::NonRecursive(_) => unreachable!(),
_ => block_of_memory_slices(env.context, tags, env.ptr_bytes),
};
let recursive_field_ptr = cast_basic_basic( let recursive_field_ptr = cast_basic_basic(
env.builder, env.builder,
ptr_as_i64_ptr, ptr_as_i64_ptr,
@ -1461,7 +1477,7 @@ fn build_rec_union_recursive_decrement<'a, 'ctx, 'env>(
// lists. To achieve it, we must first load all fields that we want to inc/dec (done above) // lists. To achieve it, we must first load all fields that we want to inc/dec (done above)
// and store them on the stack, then modify (and potentially free) the current cell, then // and store them on the stack, then modify (and potentially free) the current cell, then
// actually inc/dec the fields. // actually inc/dec the fields.
refcount_ptr.modify(call_mode, &layout, env); refcount_ptr.modify(call_mode, &Layout::Union(union_layout), env);
for (field, field_layout) in deferred_nonrec { for (field, field_layout) in deferred_nonrec {
modify_refcount_layout_help( modify_refcount_layout_help(
@ -1514,7 +1530,7 @@ fn build_rec_union_recursive_decrement<'a, 'ctx, 'env>(
env.builder.position_at_end(merge_block); env.builder.position_at_end(merge_block);
// increment/decrement the cons-cell itself // increment/decrement the cons-cell itself
refcount_ptr.modify(call_mode, &layout, env); refcount_ptr.modify(call_mode, &Layout::Union(union_layout), env);
// this function returns void // this function returns void
builder.build_return(None); builder.build_return(None);

View file

@ -1683,11 +1683,11 @@ fn binary_tree_double_pattern_match() {
foo = \btree -> foo = \btree ->
when btree is when btree is
Node (Node (Leaf x) _) _ -> x Node (Node (Leaf x) _) _ -> x
_ -> 0 _ -> 1
main : I64 main : I64
main = main =
foo (Node (Node (Leaf 32) (Leaf 0)) (Leaf 0)) foo (Node (Node (Leaf 32) (Leaf 2)) (Leaf 3))
"# "#
), ),
32, 32,