make bitcast more descriptive

Folkert 2021-01-18 22:06:14 +01:00
parent 31bf658b20
commit cb0c5ef751
8 changed files with 447 additions and 141 deletions
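
In short: call sites that went through the anonymous `cast_basic_basic` helper now either call inkwell's `build_bitcast` directly with a descriptive instruction name, or use new purpose-named wrappers around the renamed `complex_bitcast`, which threads a `name` through to the emitted LLVM IR. A condensed sketch of the pattern, with names taken from the hunks below (the `complex_bitcast` body shown here is a simplified stand-in; the real one goes through a stack slot so that aggregate values can be cast as well):

// Sketch only: signatures match the new helpers in the first changed file,
// but the complex_bitcast body below is a placeholder, not the real thing.
use inkwell::builder::Builder;
use inkwell::types::BasicTypeEnum;
use inkwell::values::BasicValueEnum;

// Before: let ptr = cast_basic_basic(builder, val, ptr_type.into());
// After:  let ptr = env.builder.build_bitcast(val, ptr_type, "cast_recursive_pointer");

pub fn complex_bitcast<'ctx>(
    builder: &Builder<'ctx>,
    from_value: BasicValueEnum<'ctx>,
    to_type: BasicTypeEnum<'ctx>,
    name: &str,
) -> BasicValueEnum<'ctx> {
    // Placeholder body; the real implementation (further down) spills the value
    // to memory and bitcasts the pointer, using `name` for the instruction.
    builder.build_bitcast(from_value, to_type, name)
}

// The old catch-all keeps compiling, but now just forwards a descriptive name.
pub fn cast_basic_basic<'ctx>(
    builder: &Builder<'ctx>,
    from_value: BasicValueEnum<'ctx>,
    to_type: BasicTypeEnum<'ctx>,
) -> BasicValueEnum<'ctx> {
    complex_bitcast(builder, from_value, to_type, "cast_basic_basic")
}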


@@ -952,11 +952,7 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
            let internal_type =
                basic_type_from_layout(env.arena, env.context, &tag_layout, env.ptr_bytes);

-            cast_basic_basic(
-                builder,
-                struct_val.into_struct_value().into(),
-                internal_type,
-            )
+            cast_tag_to_block_of_memory(builder, struct_val.into_struct_value(), internal_type)
        }
        Tag {
            arguments,
@@ -1001,10 +997,10 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
                    debug_assert!(val.is_pointer_value());

                    // we store recursive pointers as `i64*`
-                    let ptr = cast_basic_basic(
-                        builder,
+                    let ptr = env.builder.build_bitcast(
                        val,
-                        ctx.i64_type().ptr_type(AddressSpace::Generic).into(),
+                        ctx.i64_type().ptr_type(AddressSpace::Generic),
+                        "cast_recursive_pointer",
                    );

                    field_vals.push(ptr);
@@ -1020,12 +1016,14 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
            // Create the struct_type
            let data_ptr = reserve_with_refcount(env, &tag_layout);
            let struct_type = ctx.struct_type(field_types.into_bump_slice(), false);
-            let struct_ptr = cast_basic_basic(
-                builder,
-                data_ptr.into(),
-                struct_type.ptr_type(AddressSpace::Generic).into(),
-            )
-            .into_pointer_value();
+            let struct_ptr = env
+                .builder
+                .build_bitcast(
+                    data_ptr,
+                    struct_type.ptr_type(AddressSpace::Generic),
+                    "block_of_memory_to_tag",
+                )
+                .into_pointer_value();

            // Insert field exprs into struct_val
            for (index, field_val) in field_vals.into_iter().enumerate() {
@@ -1098,10 +1096,10 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
                    debug_assert!(val.is_pointer_value());

                    // we store recursive pointers as `i64*`
-                    let ptr = cast_basic_basic(
-                        builder,
+                    let ptr = env.builder.build_bitcast(
                        val,
-                        ctx.i64_type().ptr_type(AddressSpace::Generic).into(),
+                        ctx.i64_type().ptr_type(AddressSpace::Generic),
+                        "cast_recursive_pointer",
                    );

                    field_vals.push(ptr);
@@ -1117,12 +1115,14 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
            // Create the struct_type
            let data_ptr = reserve_with_refcount(env, &tag_layout);
            let struct_type = ctx.struct_type(field_types.into_bump_slice(), false);
-            let struct_ptr = cast_basic_basic(
-                builder,
-                data_ptr.into(),
-                struct_type.ptr_type(AddressSpace::Generic).into(),
-            )
-            .into_pointer_value();
+            let struct_ptr = env
+                .builder
+                .build_bitcast(
+                    data_ptr,
+                    struct_type.ptr_type(AddressSpace::Generic),
+                    "block_of_memory_to_tag",
+                )
+                .into_pointer_value();

            // Insert field exprs into struct_val
            for (index, field_val) in field_vals.into_iter().enumerate() {
@@ -1197,10 +1197,10 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
                    debug_assert!(val.is_pointer_value());

                    // we store recursive pointers as `i64*`
-                    let ptr = cast_basic_basic(
-                        builder,
+                    let ptr = env.builder.build_bitcast(
                        val,
-                        ctx.i64_type().ptr_type(AddressSpace::Generic).into(),
+                        ctx.i64_type().ptr_type(AddressSpace::Generic),
+                        "cast_recursive_pointer",
                    );

                    field_vals.push(ptr);
@@ -1220,12 +1220,14 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
            );

            let struct_type = ctx.struct_type(field_types.into_bump_slice(), false);
-            let struct_ptr = cast_basic_basic(
-                builder,
-                data_ptr.into(),
-                struct_type.ptr_type(AddressSpace::Generic).into(),
-            )
-            .into_pointer_value();
+            let struct_ptr = env
+                .builder
+                .build_bitcast(
+                    data_ptr,
+                    struct_type.ptr_type(AddressSpace::Generic),
+                    "block_of_memory_to_tag",
+                )
+                .into_pointer_value();

            // Insert field exprs into struct_val
            for (index, field_val) in field_vals.into_iter().enumerate() {
@@ -1331,7 +1333,7 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
                    .context
                    .struct_type(field_types.into_bump_slice(), false);

-                let struct_value = cast_struct_struct(builder, value, struct_type);
+                let struct_value = access_index_struct_value(builder, value, struct_type);

                let result = builder
                    .build_extract_value(struct_value, *index as u32, "")
@@ -1342,11 +1344,12 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
                        block_of_memory(env.context, &struct_layout, env.ptr_bytes);

                    // the value is a pointer to the actual value; load that value!
-                    let ptr = cast_basic_basic(
-                        builder,
+                    let ptr = env.builder.build_bitcast(
                        result,
-                        desired_type.ptr_type(AddressSpace::Generic).into(),
+                        desired_type.ptr_type(AddressSpace::Generic),
+                        "cast_struct_value_pointer",
                    );

                    builder.build_load(ptr.into_pointer_value(), "load_recursive_field")
                } else {
                    result
@@ -1494,12 +1497,14 @@ fn lookup_at_index_ptr<'a, 'ctx, 'env>(
    use inkwell::types::BasicType;

    let builder = env.builder;

-    let ptr = cast_basic_basic(
-        builder,
-        value.into(),
-        struct_type.ptr_type(AddressSpace::Generic).into(),
-    )
-    .into_pointer_value();
+    let ptr = env
+        .builder
+        .build_bitcast(
+            value,
+            struct_type.ptr_type(AddressSpace::Generic),
+            "cast_lookup_at_index_ptr",
+        )
+        .into_pointer_value();

    let elem_ptr = builder
        .build_struct_gep(ptr, index as u32, "at_index_struct_gep")
@@ -1510,12 +1515,11 @@ fn lookup_at_index_ptr<'a, 'ctx, 'env>(
    if let Some(Layout::RecursivePointer) = field_layouts.get(index as usize) {
        // a recursive field is stored as a `i64*`, to use it we must cast it to
        // a pointer to the block of memory representation
-        cast_basic_basic(
-            builder,
+        builder.build_bitcast(
            result,
            block_of_memory(env.context, &struct_layout, env.ptr_bytes)
-                .ptr_type(AddressSpace::Generic)
-                .into(),
+                .ptr_type(AddressSpace::Generic),
+            "cast_rec_pointer_lookup_at_index_ptr",
        )
    } else {
        result
@@ -1583,12 +1587,13 @@ pub fn allocate_with_refcount_help<'a, 'ctx, 'env>(
    // We must return a pointer to the first element:
    let data_ptr = {
        let int_type = ptr_int(ctx, env.ptr_bytes);
-        let as_usize_ptr = cast_basic_basic(
-            env.builder,
-            ptr.into(),
-            int_type.ptr_type(AddressSpace::Generic).into(),
-        )
-        .into_pointer_value();
+        let as_usize_ptr = builder
+            .build_bitcast(
+                ptr,
+                int_type.ptr_type(AddressSpace::Generic),
+                "to_usize_ptr",
+            )
+            .into_pointer_value();

        let index = match extra_bytes {
            n if n == env.ptr_bytes => 1,
@@ -1601,14 +1606,17 @@ pub fn allocate_with_refcount_help<'a, 'ctx, 'env>(
        let ptr_type = get_ptr_type(&value_type, AddressSpace::Generic);

        unsafe {
-            cast_basic_basic(
-                env.builder,
-                env.builder
-                    .build_in_bounds_gep(as_usize_ptr, &[index_intvalue], "get_data_ptr")
-                    .into(),
-                ptr_type.into(),
-            )
-            .into_pointer_value()
+            builder
+                .build_bitcast(
+                    env.builder.build_in_bounds_gep(
+                        as_usize_ptr,
+                        &[index_intvalue],
+                        "get_data_ptr",
+                    ),
+                    ptr_type,
+                    "malloc_cast_to_desired",
+                )
+                .into_pointer_value()
        }
    };
@@ -1664,7 +1672,7 @@ fn list_literal<'a, 'ctx, 'env>(
            let ptr_bytes = env.ptr_bytes;

            let u8_ptr_type = ctx.i8_type().ptr_type(AddressSpace::Generic);
-            let generic_ptr = cast_basic_basic(builder, ptr.into(), u8_ptr_type.into());
+            let generic_ptr = builder.build_bitcast(ptr, u8_ptr_type, "to_generic_ptr");

            let struct_type = collection(ctx, ptr_bytes);
            let len = BasicValueEnum::IntValue(env.ptr_int().const_int(len_u64, false));
@@ -2104,14 +2112,18 @@ pub fn load_symbol_and_layout<'a, 'ctx, 'env, 'b>(
        None => panic!("There was no entry for {:?} in scope {:?}", symbol, scope),
    }
}

-/// Cast a struct to another struct of the same (or smaller?) size
-pub fn cast_struct_struct<'ctx>(
+fn access_index_struct_value<'ctx>(
    builder: &Builder<'ctx>,
    from_value: StructValue<'ctx>,
    to_type: StructType<'ctx>,
) -> StructValue<'ctx> {
-    cast_basic_basic(builder, from_value.into(), to_type.into()).into_struct_value()
+    complex_bitcast(
+        builder,
+        from_value.into(),
+        to_type.into(),
+        "access_index_struct_value",
+    )
+    .into_struct_value()
}

/// Cast a value to another value of the same (or smaller?) size
@@ -2119,6 +2131,52 @@ pub fn cast_basic_basic<'ctx>(
    builder: &Builder<'ctx>,
    from_value: BasicValueEnum<'ctx>,
    to_type: BasicTypeEnum<'ctx>,
+) -> BasicValueEnum<'ctx> {
+    complex_bitcast(builder, from_value, to_type, "cast_basic_basic")
+}
+
+pub fn complex_bitcast_struct_struct<'ctx>(
+    builder: &Builder<'ctx>,
+    from_value: StructValue<'ctx>,
+    to_type: StructType<'ctx>,
+    name: &str,
+) -> StructValue<'ctx> {
+    complex_bitcast(builder, from_value.into(), to_type.into(), name).into_struct_value()
+}
+
+fn cast_tag_to_block_of_memory<'ctx>(
+    builder: &Builder<'ctx>,
+    from_value: StructValue<'ctx>,
+    to_type: BasicTypeEnum<'ctx>,
+) -> BasicValueEnum<'ctx> {
+    complex_bitcast(
+        builder,
+        from_value.into(),
+        to_type,
+        "tag_to_block_of_memory",
+    )
+}
+
+pub fn cast_block_of_memory_to_tag<'ctx>(
+    builder: &Builder<'ctx>,
+    from_value: StructValue<'ctx>,
+    to_type: BasicTypeEnum<'ctx>,
+) -> StructValue<'ctx> {
+    complex_bitcast(
+        builder,
+        from_value.into(),
+        to_type,
+        "block_of_memory_to_tag",
+    )
+    .into_struct_value()
+}
+
+/// Cast a value to another value of the same (or smaller?) size
+pub fn complex_bitcast<'ctx>(
+    builder: &Builder<'ctx>,
+    from_value: BasicValueEnum<'ctx>,
+    to_type: BasicTypeEnum<'ctx>,
+    name: &str,
) -> BasicValueEnum<'ctx> {
    use inkwell::types::BasicType;
@@ -2135,7 +2193,7 @@ pub fn cast_basic_basic<'ctx>(
        .build_bitcast(
            argument_pointer,
            to_type.ptr_type(inkwell::AddressSpace::Generic),
-            "cast_basic_basic",
+            name,
        )
        .into_pointer_value();
@@ -2150,7 +2208,12 @@ fn extract_tag_discriminant_struct<'a, 'ctx, 'env>(
        .context
        .struct_type(&[env.context.i64_type().into()], false);

-    let struct_value = cast_struct_struct(env.builder, from_value, struct_type);
+    let struct_value = complex_bitcast_struct_struct(
+        env.builder,
+        from_value,
+        struct_type,
+        "extract_tag_discriminant_struct",
+    );

    env.builder
        .build_extract_value(struct_value, 0, "")
@@ -2219,6 +2282,8 @@ fn build_switch_ir<'a, 'ctx, 'env>(
    let scope = &mut copy;
    let cond_symbol = &cond_symbol;

+    let (cond_value, stored_layout) = load_symbol_and_layout(env, scope, cond_symbol);
+    debug_assert_eq!(&cond_layout, stored_layout);
+
    let cont_block = context.append_basic_block(parent, "cont");
@@ -2227,19 +2292,17 @@ fn build_switch_ir<'a, 'ctx, 'env>(
        Layout::Builtin(Builtin::Float64) => {
            // float matches are done on the bit pattern
            cond_layout = Layout::Builtin(Builtin::Int64);

-            let full_cond = load_symbol(env, scope, cond_symbol);
            builder
-                .build_bitcast(full_cond, env.context.i64_type(), "")
+                .build_bitcast(cond_value, env.context.i64_type(), "")
                .into_int_value()
        }
        Layout::Builtin(Builtin::Float32) => {
            // float matches are done on the bit pattern
            cond_layout = Layout::Builtin(Builtin::Int32);

-            let full_cond = load_symbol(env, scope, cond_symbol);
            builder
-                .build_bitcast(full_cond, env.context.i32_type(), "")
+                .build_bitcast(cond_value, env.context.i32_type(), "")
                .into_int_value()
        }
        Layout::Union(variant) => {
@@ -2249,7 +2312,7 @@ fn build_switch_ir<'a, 'ctx, 'env>(
                NonRecursive(_) => {
                    // we match on the discriminant, not the whole Tag
                    cond_layout = Layout::Builtin(Builtin::Int64);
-                    let full_cond = load_symbol(env, scope, cond_symbol).into_struct_value();
+                    let full_cond = cond_value.into_struct_value();

                    extract_tag_discriminant_struct(env, full_cond)
                }
@@ -2257,21 +2320,13 @@ fn build_switch_ir<'a, 'ctx, 'env>(
                    // we match on the discriminant, not the whole Tag
                    cond_layout = Layout::Builtin(Builtin::Int64);
-                    use BasicValueEnum::*;
-                    match load_symbol(env, scope, cond_symbol) {
-                        PointerValue(full_cond_ptr) => {
-                            extract_tag_discriminant_ptr(env, full_cond_ptr)
-                        }
-                        StructValue(full_cond_struct) => {
-                            extract_tag_discriminant_struct(env, full_cond_struct)
-                        }
-                        _ => unreachable!(),
-                    }
+                    debug_assert!(cond_value.is_pointer_value());
+                    extract_tag_discriminant_ptr(env, cond_value.into_pointer_value())
                }
                NullableWrapped { nullable_id, .. } => {
                    // we match on the discriminant, not the whole Tag
                    cond_layout = Layout::Builtin(Builtin::Int64);
-                    let full_cond_ptr = load_symbol(env, scope, cond_symbol).into_pointer_value();
+                    let full_cond_ptr = cond_value.into_pointer_value();

                    let comparison: IntValue =
                        env.builder.build_is_null(full_cond_ptr, "is_null_cond");
@@ -2302,7 +2357,7 @@ fn build_switch_ir<'a, 'ctx, 'env>(
                }
            }
        }
-        Layout::Builtin(_) => load_symbol(env, scope, cond_symbol).into_int_value(),
+        Layout::Builtin(_) => cond_value.into_int_value(),
        other => todo!("Build switch value from layout: {:?}", other),
    };


@@ -1,5 +1,5 @@
use crate::llvm::build::{
-    cast_basic_basic, cast_struct_struct, create_entry_block_alloca, set_name, Env, Scope,
+    cast_basic_basic, cast_block_of_memory_to_tag, create_entry_block_alloca, set_name, Env, Scope,
    FAST_CALL_CONV, LLVM_SADD_WITH_OVERFLOW_I64,
};
use crate::llvm::build_list::{incrementing_elem_loop, list_len, load_list};
@@ -45,12 +45,14 @@ impl<'ctx> PointerToRefcount<'ctx> {
        // must make sure it's a pointer to usize
        let refcount_type = ptr_int(env.context, env.ptr_bytes);

-        let value = cast_basic_basic(
-            env.builder,
-            ptr.into(),
-            refcount_type.ptr_type(AddressSpace::Generic).into(),
-        )
-        .into_pointer_value();
+        let value = env
+            .builder
+            .build_bitcast(
+                ptr,
+                refcount_type.ptr_type(AddressSpace::Generic),
+                "to_refcount_ptr",
+            )
+            .into_pointer_value();

        Self { value }
    }
@@ -64,7 +66,8 @@ impl<'ctx> PointerToRefcount<'ctx> {
        let refcount_type = ptr_int(env.context, env.ptr_bytes);
        let refcount_ptr_type = refcount_type.ptr_type(AddressSpace::Generic);

-        let ptr_as_usize_ptr = cast_basic_basic(builder, data_ptr.into(), refcount_ptr_type.into())
+        let ptr_as_usize_ptr = builder
+            .build_bitcast(data_ptr, refcount_ptr_type, "as_usize_ptr")
            .into_pointer_value();

        // get a pointer to index -1
@@ -1232,12 +1235,14 @@ pub fn build_dec_rec_union_help<'a, 'ctx, 'env>(
    );

    // cast the opaque pointer to a pointer of the correct shape
-    let struct_ptr = cast_basic_basic(
-        env.builder,
-        value_ptr.into(),
-        wrapper_type.ptr_type(AddressSpace::Generic).into(),
-    )
-    .into_pointer_value();
+    let struct_ptr = env
+        .builder
+        .build_bitcast(
+            value_ptr,
+            wrapper_type.ptr_type(AddressSpace::Generic),
+            "opaque_to_correct",
+        )
+        .into_pointer_value();

    for (i, field_layout) in field_layouts.iter().enumerate() {
        if let Layout::RecursivePointer = field_layout {
@@ -1428,8 +1433,8 @@ pub fn build_dec_union_help<'a, 'ctx, 'env>(
                env.ptr_bytes,
            );

-            let wrapper_struct =
-                cast_struct_struct(env.builder, wrapper_struct, wrapper_type.into_struct_type());
+            debug_assert!(wrapper_type.is_struct_type());
+            let wrapper_struct = cast_block_of_memory_to_tag(env.builder, wrapper_struct, wrapper_type);

            for (i, field_layout) in field_layouts.iter().enumerate() {
                if let Layout::RecursivePointer = field_layout {
@@ -1528,15 +1533,11 @@ fn rec_union_read_tag<'a, 'ctx, 'env>(
) -> IntValue<'ctx> {
    // Assumption: the tag is the first thing stored
    // so cast the pointer to the data to a `i64*`
-    let tag_ptr = cast_basic_basic(
-        env.builder,
-        value_ptr.into(),
-        env.context
-            .i64_type()
-            .ptr_type(AddressSpace::Generic)
-            .into(),
-    )
-    .into_pointer_value();
+    let tag_ptr_type = env.context.i64_type().ptr_type(AddressSpace::Generic);
+    let tag_ptr = env
+        .builder
+        .build_bitcast(value_ptr, tag_ptr_type, "cast_tag_ptr")
+        .into_pointer_value();

    env.builder
        .build_load(tag_ptr, "load_tag_id")
@@ -1634,12 +1635,14 @@ pub fn build_inc_rec_union_help<'a, 'ctx, 'env>(
    );

    // cast the opaque pointer to a pointer of the correct shape
-    let struct_ptr = cast_basic_basic(
-        env.builder,
-        value_ptr.into(),
-        wrapper_type.ptr_type(AddressSpace::Generic).into(),
-    )
-    .into_pointer_value();
+    let struct_ptr = env
+        .builder
+        .build_bitcast(
+            value_ptr,
+            wrapper_type.ptr_type(AddressSpace::Generic),
+            "opaque_to_correct",
+        )
+        .into_pointer_value();

    for (i, field_layout) in field_layouts.iter().enumerate() {
        if let Layout::RecursivePointer = field_layout {
@@ -1657,10 +1660,10 @@ pub fn build_inc_rec_union_help<'a, 'ctx, 'env>(
            // therefore we must cast it to our desired type
            let union_type = block_of_memory_slices(env.context, tags, env.ptr_bytes);
-            let recursive_field_ptr = cast_basic_basic(
-                env.builder,
+            let recursive_field_ptr = env.builder.build_bitcast(
                ptr_as_i64_ptr,
-                union_type.ptr_type(AddressSpace::Generic).into(),
+                union_type.ptr_type(AddressSpace::Generic),
+                "recursive_to_desired",
            );

            // recursively increment the field
@@ -1694,10 +1697,11 @@ pub fn build_inc_rec_union_help<'a, 'ctx, 'env>(
    // read the tag_id
    let tag_id = rec_union_read_tag(env, value_ptr);

-    let tag_id_u8 = cast_basic_basic(env.builder, tag_id.into(), env.context.i8_type().into());
+    let tag_id_u8 = env
+        .builder
+        .build_int_cast(tag_id, env.context.i8_type(), "tag_id_u8");

-    env.builder
-        .build_switch(tag_id_u8.into_int_value(), merge_block, &cases);
+    env.builder.build_switch(tag_id_u8, merge_block, &cases);

    env.builder.position_at_end(merge_block);
@@ -1808,7 +1812,9 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>(
            .into_int_value()
    };

-    let tag_id_u8 = cast_basic_basic(env.builder, tag_id.into(), env.context.i8_type().into());
+    let tag_id_u8 = env
+        .builder
+        .build_int_cast(tag_id, env.context.i8_type(), "tag_id_u8");

    // next, make a jump table for all possible values of the tag_id
    let mut cases = Vec::with_capacity_in(tags.len(), env.arena);
@@ -1834,8 +1840,8 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>(
                env.ptr_bytes,
            );

-            let wrapper_struct =
-                cast_struct_struct(env.builder, wrapper_struct, wrapper_type.into_struct_type());
+            debug_assert!(wrapper_type.is_struct_type());
+            let wrapper_struct = cast_block_of_memory_to_tag(env.builder, wrapper_struct, wrapper_type);

            for (i, field_layout) in field_layouts.iter().enumerate() {
                if let Layout::RecursivePointer = field_layout {
@@ -1849,12 +1855,14 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>(
                    // therefore we must cast it to our desired type
                    let union_type = block_of_memory(env.context, &layout, env.ptr_bytes);
-                    let recursive_field_ptr = cast_basic_basic(
-                        env.builder,
-                        ptr_as_i64_ptr,
-                        union_type.ptr_type(AddressSpace::Generic).into(),
-                    )
-                    .into_pointer_value();
+                    let recursive_field_ptr = env
+                        .builder
+                        .build_bitcast(
+                            ptr_as_i64_ptr,
+                            union_type.ptr_type(AddressSpace::Generic),
+                            "recursive_to_desired",
+                        )
+                        .into_pointer_value();

                    let recursive_field = env
                        .builder
@@ -1889,8 +1897,7 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>(
    env.builder.position_at_end(before_block);

-    env.builder
-        .build_switch(tag_id_u8.into_int_value(), merge_block, &cases);
+    env.builder.build_switch(tag_id_u8, merge_block, &cases);

    env.builder.position_at_end(merge_block);


@@ -2074,4 +2074,46 @@ mod gen_primitives {
            i64
        );
    }
+
+    #[test]
+    fn bug_exposer() {
+        // the decision tree will generate a jump to the `1` branch here
+        assert_evals_to!(
+            indoc!(
+                r#"
+                app "test" provides [ main ] to "./platform"
+
+                Expr : [ ZAdd Expr Expr, Val I64, Var I64 ]
+
+                eval : Expr -> I64
+                eval = \e ->
+                    when e is
+                        Var _ -> 0
+                        Val v -> v
+                        ZAdd l r -> eval l + eval r
+
+                constFolding : Expr -> Expr
+                constFolding = \e ->
+                    when e is
+                        ZAdd e1 e2 ->
+                            when Pair e1 e2 is
+                                Pair (Val a) (Val b) -> Val (a+b)
+                                Pair (Val a) (ZAdd x (Val b)) -> ZAdd (Val (a+b)) x
+                                Pair _ _ -> ZAdd e1 e2
+
+                        _ -> e
+
+                expr : Expr
+                expr = ZAdd (Val 3) (ZAdd (Val 4) (Val 5))
+
+                main : I64
+                main = eval (constFolding expr)
+                "#
+            ),
+            12,
+            i64
+        );
+    }
}


@@ -538,4 +538,34 @@ mod gen_str {
        debug_assert_eq!(short.clone(), short);
        debug_assert_eq!(empty.clone(), empty);
    }
+
+    #[test]
+    fn nested_recursive_literal() {
+        assert_evals_to!(
+            indoc!(
+                r#"
+                Expr : [ Add Expr Expr, Val I64, Var I64 ]
+
+                expr : Expr
+                expr = Add (Add (Val 3) (Val 1)) (Add (Val 1) (Var 1))
+
+                printExpr : Expr -> Str
+                printExpr = \e ->
+                    when e is
+                        Add a b ->
+                            "Add ("
+                                |> Str.concat (printExpr a)
+                                |> Str.concat ") ("
+                                |> Str.concat (printExpr b)
+                                |> Str.concat ")"
+                        Val v -> "Val " |> Str.concat (Str.fromInt v)
+                        Var v -> "Var " |> Str.concat (Str.fromInt v)
+
+                printExpr expr
+                "#
+            ),
+            "Add (Add (Val 3) (Val 1)) (Add (Val 1) (Var 1))",
+            &'static str
+        );
+    }
}


@@ -941,4 +941,23 @@ mod gen_tags {
            i64
        );
    }
+
+    #[test]
+    fn nested_recursive_literal() {
+        assert_evals_to!(
+            indoc!(
+                r"#
+                Expr : [ Add Expr Expr, Val I64, Var I64 ]
+
+                e : Expr
+                e = Add (Add (Val 3) (Val 1)) (Add (Val 1) (Var 1))
+
+                e
+                #"
+            ),
+            0,
+            &i64,
+            |x: &i64| *x
+        );
+    }
}


examples/task/CFold.roc (new file, 99 lines)

app "cfold"
packages { base: "thing/platform-dir" }
imports [base.Task]
provides [ main ] to base
# adapted from https://github.com/koka-lang/koka/blob/master/test/bench/haskell/cfold.hs
main : Task.Task {} []
main =
e = mkExpr 3 1
unoptimized = eval e
optimized = eval (constFolding (reassoc e))
unoptimized
|> Str.fromInt
|> Str.concat " & "
|> Str.concat (Str.fromInt optimized)
|> Task.putLine
Expr : [
Add Expr Expr,
Mul Expr Expr,
Val I64,
Var I64
]
mkExpr : I64, I64 -> Expr
mkExpr = \n , v ->
when n is
0 -> if v == 0 then Var 1 else Val v
_ -> Add (mkExpr (n-1) (v+1)) (mkExpr (n-1) (max (v-1) 0))
max : I64, I64 -> I64
max = \a, b -> if a > b then a else b
appendAdd : Expr, Expr -> Expr
appendAdd = \e1, e2 ->
when e1 is
Add a1 a2 -> Add a1 (appendAdd a2 e2)
_ -> Add e1 e2
appendMul : Expr, Expr -> Expr
appendMul = \e1, e2 ->
when e1 is
Mul a1 a2 -> Mul a1 (appendMul a2 e2)
_ -> Mul e1 e2
eval : Expr -> I64
eval = \e ->
when e is
Var _ -> 0
Val v -> v
Add l r -> eval l + eval r
Mul l r -> eval l * eval r
reassoc : Expr -> Expr
reassoc = \e ->
when e is
Add e1 e2 ->
x1 = reassoc e1
x2 = reassoc e2
appendAdd x1 x2
Mul e1 e2 ->
x1 = reassoc e1
x2 = reassoc e2
appendMul x1 x2
_ -> e
constFolding : Expr -> Expr
constFolding = \e ->
when e is
Add e1 e2 ->
x1 = constFolding e1
x2 = constFolding e2
when Pair x1 x2 is
Pair (Val a) (Val b) -> Val (a+b)
# Pair (Val a) (Add (Val b) x) -> Add (Val (a+b)) x
Pair (Val a) (Add x (Val b)) -> Add (Val (a+b)) x
Pair _ _ -> Add x1 x2
Mul e1 e2 ->
x1 = constFolding e1
x2 = constFolding e2
when Pair x1 x2 is
Pair (Val a) (Val b) -> Val (a*b)
Pair (Val a) (Mul (Val b) x) -> Mul (Val (a*b)) x
Pair (Val a) (Mul x (Val b)) -> Mul (Val (a*b)) x
Pair _ _ -> Mul x1 x2
_ -> e

examples/task/NQueens.roc (new file, 54 lines)

app "nqueens"
packages { base: "thing/platform-dir" }
imports [base.Task]
provides [ main ] to base
main : Task.Task {} []
main =
queens 10
|> Str.fromInt
|> Task.putLine
ConsList a : [ Nil, Cons a (ConsList a) ]
queens = \n -> length (findSolutions n n)
length : ConsList a -> I64
length = \xs -> lengthHelp xs 0
lengthHelp : ConsList a, I64 -> I64
lengthHelp = \xs, acc ->
when xs is
Nil -> acc
Cons _ rest -> lengthHelp rest (1 + acc)
safe : I64, I64, ConsList I64 -> Bool
safe = \queen, diagonal, xs ->
when xs is
Nil ->
True
Cons q t ->
queen != q && queen != q + diagonal && queen != q - diagonal && safe queen (diagonal + 1) t
appendSafe : I64, ConsList I64, ConsList (ConsList I64) -> ConsList (ConsList I64)
appendSafe = \k, soln, solns ->
if k <= 0 then
solns
else if safe k 1 soln then
appendSafe (k - 1) soln (Cons (Cons k soln) solns)
else
appendSafe (k - 1) soln solns
extend = \n, acc, solns ->
when solns is
Nil -> acc
Cons soln rest -> extend n (appendSafe n soln acc) rest
findSolutions = \n, k ->
if k == 0 then
Cons Nil Nil
else
extend n Nil (findSolutions n (k - 1))