make bitcast more descriptive

This commit is contained in:
Folkert 2021-01-18 22:06:14 +01:00
parent 31bf658b20
commit cb0c5ef751
8 changed files with 447 additions and 141 deletions

View file

@ -952,11 +952,7 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
let internal_type = let internal_type =
basic_type_from_layout(env.arena, env.context, &tag_layout, env.ptr_bytes); basic_type_from_layout(env.arena, env.context, &tag_layout, env.ptr_bytes);
cast_basic_basic( cast_tag_to_block_of_memory(builder, struct_val.into_struct_value(), internal_type)
builder,
struct_val.into_struct_value().into(),
internal_type,
)
} }
Tag { Tag {
arguments, arguments,
@ -1001,10 +997,10 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
debug_assert!(val.is_pointer_value()); debug_assert!(val.is_pointer_value());
// we store recursive pointers as `i64*` // we store recursive pointers as `i64*`
let ptr = cast_basic_basic( let ptr = env.builder.build_bitcast(
builder,
val, val,
ctx.i64_type().ptr_type(AddressSpace::Generic).into(), ctx.i64_type().ptr_type(AddressSpace::Generic),
"cast_recursive_pointer",
); );
field_vals.push(ptr); field_vals.push(ptr);
@ -1020,12 +1016,14 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
// Create the struct_type // Create the struct_type
let data_ptr = reserve_with_refcount(env, &tag_layout); let data_ptr = reserve_with_refcount(env, &tag_layout);
let struct_type = ctx.struct_type(field_types.into_bump_slice(), false); let struct_type = ctx.struct_type(field_types.into_bump_slice(), false);
let struct_ptr = cast_basic_basic( let struct_ptr = env
builder, .builder
data_ptr.into(), .build_bitcast(
struct_type.ptr_type(AddressSpace::Generic).into(), data_ptr,
) struct_type.ptr_type(AddressSpace::Generic),
.into_pointer_value(); "block_of_memory_to_tag",
)
.into_pointer_value();
// Insert field exprs into struct_val // Insert field exprs into struct_val
for (index, field_val) in field_vals.into_iter().enumerate() { for (index, field_val) in field_vals.into_iter().enumerate() {
@ -1098,10 +1096,10 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
debug_assert!(val.is_pointer_value()); debug_assert!(val.is_pointer_value());
// we store recursive pointers as `i64*` // we store recursive pointers as `i64*`
let ptr = cast_basic_basic( let ptr = env.builder.build_bitcast(
builder,
val, val,
ctx.i64_type().ptr_type(AddressSpace::Generic).into(), ctx.i64_type().ptr_type(AddressSpace::Generic),
"cast_recursive_pointer",
); );
field_vals.push(ptr); field_vals.push(ptr);
@ -1117,12 +1115,14 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
// Create the struct_type // Create the struct_type
let data_ptr = reserve_with_refcount(env, &tag_layout); let data_ptr = reserve_with_refcount(env, &tag_layout);
let struct_type = ctx.struct_type(field_types.into_bump_slice(), false); let struct_type = ctx.struct_type(field_types.into_bump_slice(), false);
let struct_ptr = cast_basic_basic( let struct_ptr = env
builder, .builder
data_ptr.into(), .build_bitcast(
struct_type.ptr_type(AddressSpace::Generic).into(), data_ptr,
) struct_type.ptr_type(AddressSpace::Generic),
.into_pointer_value(); "block_of_memory_to_tag",
)
.into_pointer_value();
// Insert field exprs into struct_val // Insert field exprs into struct_val
for (index, field_val) in field_vals.into_iter().enumerate() { for (index, field_val) in field_vals.into_iter().enumerate() {
@ -1197,10 +1197,10 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
debug_assert!(val.is_pointer_value()); debug_assert!(val.is_pointer_value());
// we store recursive pointers as `i64*` // we store recursive pointers as `i64*`
let ptr = cast_basic_basic( let ptr = env.builder.build_bitcast(
builder,
val, val,
ctx.i64_type().ptr_type(AddressSpace::Generic).into(), ctx.i64_type().ptr_type(AddressSpace::Generic),
"cast_recursive_pointer",
); );
field_vals.push(ptr); field_vals.push(ptr);
@ -1220,12 +1220,14 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
); );
let struct_type = ctx.struct_type(field_types.into_bump_slice(), false); let struct_type = ctx.struct_type(field_types.into_bump_slice(), false);
let struct_ptr = cast_basic_basic( let struct_ptr = env
builder, .builder
data_ptr.into(), .build_bitcast(
struct_type.ptr_type(AddressSpace::Generic).into(), data_ptr,
) struct_type.ptr_type(AddressSpace::Generic),
.into_pointer_value(); "block_of_memory_to_tag",
)
.into_pointer_value();
// Insert field exprs into struct_val // Insert field exprs into struct_val
for (index, field_val) in field_vals.into_iter().enumerate() { for (index, field_val) in field_vals.into_iter().enumerate() {
@ -1331,7 +1333,7 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
.context .context
.struct_type(field_types.into_bump_slice(), false); .struct_type(field_types.into_bump_slice(), false);
let struct_value = cast_struct_struct(builder, value, struct_type); let struct_value = access_index_struct_value(builder, value, struct_type);
let result = builder let result = builder
.build_extract_value(struct_value, *index as u32, "") .build_extract_value(struct_value, *index as u32, "")
@ -1342,11 +1344,12 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
block_of_memory(env.context, &struct_layout, env.ptr_bytes); block_of_memory(env.context, &struct_layout, env.ptr_bytes);
// the value is a pointer to the actual value; load that value! // the value is a pointer to the actual value; load that value!
let ptr = cast_basic_basic( let ptr = env.builder.build_bitcast(
builder,
result, result,
desired_type.ptr_type(AddressSpace::Generic).into(), desired_type.ptr_type(AddressSpace::Generic),
"cast_struct_value_pointer",
); );
builder.build_load(ptr.into_pointer_value(), "load_recursive_field") builder.build_load(ptr.into_pointer_value(), "load_recursive_field")
} else { } else {
result result
@ -1494,12 +1497,14 @@ fn lookup_at_index_ptr<'a, 'ctx, 'env>(
use inkwell::types::BasicType; use inkwell::types::BasicType;
let builder = env.builder; let builder = env.builder;
let ptr = cast_basic_basic( let ptr = env
builder, .builder
value.into(), .build_bitcast(
struct_type.ptr_type(AddressSpace::Generic).into(), value,
) struct_type.ptr_type(AddressSpace::Generic),
.into_pointer_value(); "cast_lookup_at_index_ptr",
)
.into_pointer_value();
let elem_ptr = builder let elem_ptr = builder
.build_struct_gep(ptr, index as u32, "at_index_struct_gep") .build_struct_gep(ptr, index as u32, "at_index_struct_gep")
@ -1510,12 +1515,11 @@ fn lookup_at_index_ptr<'a, 'ctx, 'env>(
if let Some(Layout::RecursivePointer) = field_layouts.get(index as usize) { if let Some(Layout::RecursivePointer) = field_layouts.get(index as usize) {
// a recursive field is stored as a `i64*`, to use it we must cast it to // a recursive field is stored as a `i64*`, to use it we must cast it to
// a pointer to the block of memory representation // a pointer to the block of memory representation
cast_basic_basic( builder.build_bitcast(
builder,
result, result,
block_of_memory(env.context, &struct_layout, env.ptr_bytes) block_of_memory(env.context, &struct_layout, env.ptr_bytes)
.ptr_type(AddressSpace::Generic) .ptr_type(AddressSpace::Generic),
.into(), "cast_rec_pointer_lookup_at_index_ptr",
) )
} else { } else {
result result
@ -1583,12 +1587,13 @@ pub fn allocate_with_refcount_help<'a, 'ctx, 'env>(
// We must return a pointer to the first element: // We must return a pointer to the first element:
let data_ptr = { let data_ptr = {
let int_type = ptr_int(ctx, env.ptr_bytes); let int_type = ptr_int(ctx, env.ptr_bytes);
let as_usize_ptr = cast_basic_basic( let as_usize_ptr = builder
env.builder, .build_bitcast(
ptr.into(), ptr,
int_type.ptr_type(AddressSpace::Generic).into(), int_type.ptr_type(AddressSpace::Generic),
) "to_usize_ptr",
.into_pointer_value(); )
.into_pointer_value();
let index = match extra_bytes { let index = match extra_bytes {
n if n == env.ptr_bytes => 1, n if n == env.ptr_bytes => 1,
@ -1601,14 +1606,17 @@ pub fn allocate_with_refcount_help<'a, 'ctx, 'env>(
let ptr_type = get_ptr_type(&value_type, AddressSpace::Generic); let ptr_type = get_ptr_type(&value_type, AddressSpace::Generic);
unsafe { unsafe {
cast_basic_basic( builder
env.builder, .build_bitcast(
env.builder env.builder.build_in_bounds_gep(
.build_in_bounds_gep(as_usize_ptr, &[index_intvalue], "get_data_ptr") as_usize_ptr,
.into(), &[index_intvalue],
ptr_type.into(), "get_data_ptr",
) ),
.into_pointer_value() ptr_type,
"malloc_cast_to_desired",
)
.into_pointer_value()
} }
}; };
@ -1664,7 +1672,7 @@ fn list_literal<'a, 'ctx, 'env>(
let ptr_bytes = env.ptr_bytes; let ptr_bytes = env.ptr_bytes;
let u8_ptr_type = ctx.i8_type().ptr_type(AddressSpace::Generic); let u8_ptr_type = ctx.i8_type().ptr_type(AddressSpace::Generic);
let generic_ptr = cast_basic_basic(builder, ptr.into(), u8_ptr_type.into()); let generic_ptr = builder.build_bitcast(ptr, u8_ptr_type, "to_generic_ptr");
let struct_type = collection(ctx, ptr_bytes); let struct_type = collection(ctx, ptr_bytes);
let len = BasicValueEnum::IntValue(env.ptr_int().const_int(len_u64, false)); let len = BasicValueEnum::IntValue(env.ptr_int().const_int(len_u64, false));
@ -2104,14 +2112,18 @@ pub fn load_symbol_and_layout<'a, 'ctx, 'env, 'b>(
None => panic!("There was no entry for {:?} in scope {:?}", symbol, scope), None => panic!("There was no entry for {:?} in scope {:?}", symbol, scope),
} }
} }
fn access_index_struct_value<'ctx>(
/// Cast a struct to another struct of the same (or smaller?) size
pub fn cast_struct_struct<'ctx>(
builder: &Builder<'ctx>, builder: &Builder<'ctx>,
from_value: StructValue<'ctx>, from_value: StructValue<'ctx>,
to_type: StructType<'ctx>, to_type: StructType<'ctx>,
) -> StructValue<'ctx> { ) -> StructValue<'ctx> {
cast_basic_basic(builder, from_value.into(), to_type.into()).into_struct_value() complex_bitcast(
builder,
from_value.into(),
to_type.into(),
"access_index_struct_value",
)
.into_struct_value()
} }
/// Cast a value to another value of the same (or smaller?) size /// Cast a value to another value of the same (or smaller?) size
@ -2119,6 +2131,52 @@ pub fn cast_basic_basic<'ctx>(
builder: &Builder<'ctx>, builder: &Builder<'ctx>,
from_value: BasicValueEnum<'ctx>, from_value: BasicValueEnum<'ctx>,
to_type: BasicTypeEnum<'ctx>, to_type: BasicTypeEnum<'ctx>,
) -> BasicValueEnum<'ctx> {
complex_bitcast(builder, from_value, to_type, "cast_basic_basic")
}
pub fn complex_bitcast_struct_struct<'ctx>(
builder: &Builder<'ctx>,
from_value: StructValue<'ctx>,
to_type: StructType<'ctx>,
name: &str,
) -> StructValue<'ctx> {
complex_bitcast(builder, from_value.into(), to_type.into(), name).into_struct_value()
}
fn cast_tag_to_block_of_memory<'ctx>(
builder: &Builder<'ctx>,
from_value: StructValue<'ctx>,
to_type: BasicTypeEnum<'ctx>,
) -> BasicValueEnum<'ctx> {
complex_bitcast(
builder,
from_value.into(),
to_type,
"tag_to_block_of_memory",
)
}
pub fn cast_block_of_memory_to_tag<'ctx>(
builder: &Builder<'ctx>,
from_value: StructValue<'ctx>,
to_type: BasicTypeEnum<'ctx>,
) -> StructValue<'ctx> {
complex_bitcast(
builder,
from_value.into(),
to_type,
"block_of_memory_to_tag",
)
.into_struct_value()
}
/// Cast a value to another value of the same (or smaller?) size
pub fn complex_bitcast<'ctx>(
builder: &Builder<'ctx>,
from_value: BasicValueEnum<'ctx>,
to_type: BasicTypeEnum<'ctx>,
name: &str,
) -> BasicValueEnum<'ctx> { ) -> BasicValueEnum<'ctx> {
use inkwell::types::BasicType; use inkwell::types::BasicType;
@ -2135,7 +2193,7 @@ pub fn cast_basic_basic<'ctx>(
.build_bitcast( .build_bitcast(
argument_pointer, argument_pointer,
to_type.ptr_type(inkwell::AddressSpace::Generic), to_type.ptr_type(inkwell::AddressSpace::Generic),
"cast_basic_basic", name,
) )
.into_pointer_value(); .into_pointer_value();
@ -2150,7 +2208,12 @@ fn extract_tag_discriminant_struct<'a, 'ctx, 'env>(
.context .context
.struct_type(&[env.context.i64_type().into()], false); .struct_type(&[env.context.i64_type().into()], false);
let struct_value = cast_struct_struct(env.builder, from_value, struct_type); let struct_value = complex_bitcast_struct_struct(
env.builder,
from_value,
struct_type,
"extract_tag_discriminant_struct",
);
env.builder env.builder
.build_extract_value(struct_value, 0, "") .build_extract_value(struct_value, 0, "")
@ -2219,6 +2282,8 @@ fn build_switch_ir<'a, 'ctx, 'env>(
let scope = &mut copy; let scope = &mut copy;
let cond_symbol = &cond_symbol; let cond_symbol = &cond_symbol;
let (cond_value, stored_layout) = load_symbol_and_layout(env, scope, cond_symbol);
debug_assert_eq!(&cond_layout, stored_layout);
let cont_block = context.append_basic_block(parent, "cont"); let cont_block = context.append_basic_block(parent, "cont");
@ -2227,19 +2292,17 @@ fn build_switch_ir<'a, 'ctx, 'env>(
Layout::Builtin(Builtin::Float64) => { Layout::Builtin(Builtin::Float64) => {
// float matches are done on the bit pattern // float matches are done on the bit pattern
cond_layout = Layout::Builtin(Builtin::Int64); cond_layout = Layout::Builtin(Builtin::Int64);
let full_cond = load_symbol(env, scope, cond_symbol);
builder builder
.build_bitcast(full_cond, env.context.i64_type(), "") .build_bitcast(cond_value, env.context.i64_type(), "")
.into_int_value() .into_int_value()
} }
Layout::Builtin(Builtin::Float32) => { Layout::Builtin(Builtin::Float32) => {
// float matches are done on the bit pattern // float matches are done on the bit pattern
cond_layout = Layout::Builtin(Builtin::Int32); cond_layout = Layout::Builtin(Builtin::Int32);
let full_cond = load_symbol(env, scope, cond_symbol);
builder builder
.build_bitcast(full_cond, env.context.i32_type(), "") .build_bitcast(cond_value, env.context.i32_type(), "")
.into_int_value() .into_int_value()
} }
Layout::Union(variant) => { Layout::Union(variant) => {
@ -2249,7 +2312,7 @@ fn build_switch_ir<'a, 'ctx, 'env>(
NonRecursive(_) => { NonRecursive(_) => {
// we match on the discriminant, not the whole Tag // we match on the discriminant, not the whole Tag
cond_layout = Layout::Builtin(Builtin::Int64); cond_layout = Layout::Builtin(Builtin::Int64);
let full_cond = load_symbol(env, scope, cond_symbol).into_struct_value(); let full_cond = cond_value.into_struct_value();
extract_tag_discriminant_struct(env, full_cond) extract_tag_discriminant_struct(env, full_cond)
} }
@ -2257,21 +2320,13 @@ fn build_switch_ir<'a, 'ctx, 'env>(
// we match on the discriminant, not the whole Tag // we match on the discriminant, not the whole Tag
cond_layout = Layout::Builtin(Builtin::Int64); cond_layout = Layout::Builtin(Builtin::Int64);
use BasicValueEnum::*; debug_assert!(cond_value.is_pointer_value());
match load_symbol(env, scope, cond_symbol) { extract_tag_discriminant_ptr(env, cond_value.into_pointer_value())
PointerValue(full_cond_ptr) => {
extract_tag_discriminant_ptr(env, full_cond_ptr)
}
StructValue(full_cond_struct) => {
extract_tag_discriminant_struct(env, full_cond_struct)
}
_ => unreachable!(),
}
} }
NullableWrapped { nullable_id, .. } => { NullableWrapped { nullable_id, .. } => {
// we match on the discriminant, not the whole Tag // we match on the discriminant, not the whole Tag
cond_layout = Layout::Builtin(Builtin::Int64); cond_layout = Layout::Builtin(Builtin::Int64);
let full_cond_ptr = load_symbol(env, scope, cond_symbol).into_pointer_value(); let full_cond_ptr = cond_value.into_pointer_value();
let comparison: IntValue = let comparison: IntValue =
env.builder.build_is_null(full_cond_ptr, "is_null_cond"); env.builder.build_is_null(full_cond_ptr, "is_null_cond");
@ -2302,7 +2357,7 @@ fn build_switch_ir<'a, 'ctx, 'env>(
} }
} }
} }
Layout::Builtin(_) => load_symbol(env, scope, cond_symbol).into_int_value(), Layout::Builtin(_) => cond_value.into_int_value(),
other => todo!("Build switch value from layout: {:?}", other), other => todo!("Build switch value from layout: {:?}", other),
}; };

View file

@ -1,5 +1,5 @@
use crate::llvm::build::{ use crate::llvm::build::{
cast_basic_basic, cast_struct_struct, create_entry_block_alloca, set_name, Env, Scope, cast_basic_basic, cast_block_of_memory_to_tag, create_entry_block_alloca, set_name, Env, Scope,
FAST_CALL_CONV, LLVM_SADD_WITH_OVERFLOW_I64, FAST_CALL_CONV, LLVM_SADD_WITH_OVERFLOW_I64,
}; };
use crate::llvm::build_list::{incrementing_elem_loop, list_len, load_list}; use crate::llvm::build_list::{incrementing_elem_loop, list_len, load_list};
@ -45,12 +45,14 @@ impl<'ctx> PointerToRefcount<'ctx> {
// must make sure it's a pointer to usize // must make sure it's a pointer to usize
let refcount_type = ptr_int(env.context, env.ptr_bytes); let refcount_type = ptr_int(env.context, env.ptr_bytes);
let value = cast_basic_basic( let value = env
env.builder, .builder
ptr.into(), .build_bitcast(
refcount_type.ptr_type(AddressSpace::Generic).into(), ptr,
) refcount_type.ptr_type(AddressSpace::Generic),
.into_pointer_value(); "to_refcount_ptr",
)
.into_pointer_value();
Self { value } Self { value }
} }
@ -64,7 +66,8 @@ impl<'ctx> PointerToRefcount<'ctx> {
let refcount_type = ptr_int(env.context, env.ptr_bytes); let refcount_type = ptr_int(env.context, env.ptr_bytes);
let refcount_ptr_type = refcount_type.ptr_type(AddressSpace::Generic); let refcount_ptr_type = refcount_type.ptr_type(AddressSpace::Generic);
let ptr_as_usize_ptr = cast_basic_basic(builder, data_ptr.into(), refcount_ptr_type.into()) let ptr_as_usize_ptr = builder
.build_bitcast(data_ptr, refcount_ptr_type, "as_usize_ptr")
.into_pointer_value(); .into_pointer_value();
// get a pointer to index -1 // get a pointer to index -1
@ -1232,12 +1235,14 @@ pub fn build_dec_rec_union_help<'a, 'ctx, 'env>(
); );
// cast the opaque pointer to a pointer of the correct shape // cast the opaque pointer to a pointer of the correct shape
let struct_ptr = cast_basic_basic( let struct_ptr = env
env.builder, .builder
value_ptr.into(), .build_bitcast(
wrapper_type.ptr_type(AddressSpace::Generic).into(), value_ptr,
) wrapper_type.ptr_type(AddressSpace::Generic),
.into_pointer_value(); "opaque_to_correct",
)
.into_pointer_value();
for (i, field_layout) in field_layouts.iter().enumerate() { for (i, field_layout) in field_layouts.iter().enumerate() {
if let Layout::RecursivePointer = field_layout { if let Layout::RecursivePointer = field_layout {
@ -1428,8 +1433,8 @@ pub fn build_dec_union_help<'a, 'ctx, 'env>(
env.ptr_bytes, env.ptr_bytes,
); );
let wrapper_struct = debug_assert!(wrapper_type.is_struct_type());
cast_struct_struct(env.builder, wrapper_struct, wrapper_type.into_struct_type()); let wrapper_struct = cast_block_of_memory_to_tag(env.builder, wrapper_struct, wrapper_type);
for (i, field_layout) in field_layouts.iter().enumerate() { for (i, field_layout) in field_layouts.iter().enumerate() {
if let Layout::RecursivePointer = field_layout { if let Layout::RecursivePointer = field_layout {
@ -1528,15 +1533,11 @@ fn rec_union_read_tag<'a, 'ctx, 'env>(
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
// Assumption: the tag is the first thing stored // Assumption: the tag is the first thing stored
// so cast the pointer to the data to a `i64*` // so cast the pointer to the data to a `i64*`
let tag_ptr = cast_basic_basic( let tag_ptr_type = env.context.i64_type().ptr_type(AddressSpace::Generic);
env.builder, let tag_ptr = env
value_ptr.into(), .builder
env.context .build_bitcast(value_ptr, tag_ptr_type, "cast_tag_ptr")
.i64_type() .into_pointer_value();
.ptr_type(AddressSpace::Generic)
.into(),
)
.into_pointer_value();
env.builder env.builder
.build_load(tag_ptr, "load_tag_id") .build_load(tag_ptr, "load_tag_id")
@ -1634,12 +1635,14 @@ pub fn build_inc_rec_union_help<'a, 'ctx, 'env>(
); );
// cast the opaque pointer to a pointer of the correct shape // cast the opaque pointer to a pointer of the correct shape
let struct_ptr = cast_basic_basic( let struct_ptr = env
env.builder, .builder
value_ptr.into(), .build_bitcast(
wrapper_type.ptr_type(AddressSpace::Generic).into(), value_ptr,
) wrapper_type.ptr_type(AddressSpace::Generic),
.into_pointer_value(); "opaque_to_correct",
)
.into_pointer_value();
for (i, field_layout) in field_layouts.iter().enumerate() { for (i, field_layout) in field_layouts.iter().enumerate() {
if let Layout::RecursivePointer = field_layout { if let Layout::RecursivePointer = field_layout {
@ -1657,10 +1660,10 @@ pub fn build_inc_rec_union_help<'a, 'ctx, 'env>(
// therefore we must cast it to our desired type // therefore we must cast it to our desired type
let union_type = block_of_memory_slices(env.context, tags, env.ptr_bytes); let union_type = block_of_memory_slices(env.context, tags, env.ptr_bytes);
let recursive_field_ptr = cast_basic_basic( let recursive_field_ptr = env.builder.build_bitcast(
env.builder,
ptr_as_i64_ptr, ptr_as_i64_ptr,
union_type.ptr_type(AddressSpace::Generic).into(), union_type.ptr_type(AddressSpace::Generic),
"recursive_to_desired",
); );
// recursively increment the field // recursively increment the field
@ -1694,10 +1697,11 @@ pub fn build_inc_rec_union_help<'a, 'ctx, 'env>(
// read the tag_id // read the tag_id
let tag_id = rec_union_read_tag(env, value_ptr); let tag_id = rec_union_read_tag(env, value_ptr);
let tag_id_u8 = cast_basic_basic(env.builder, tag_id.into(), env.context.i8_type().into()); let tag_id_u8 = env
.builder
.build_int_cast(tag_id, env.context.i8_type(), "tag_id_u8");
env.builder env.builder.build_switch(tag_id_u8, merge_block, &cases);
.build_switch(tag_id_u8.into_int_value(), merge_block, &cases);
env.builder.position_at_end(merge_block); env.builder.position_at_end(merge_block);
@ -1808,7 +1812,9 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>(
.into_int_value() .into_int_value()
}; };
let tag_id_u8 = cast_basic_basic(env.builder, tag_id.into(), env.context.i8_type().into()); let tag_id_u8 = env
.builder
.build_int_cast(tag_id, env.context.i8_type(), "tag_id_u8");
// next, make a jump table for all possible values of the tag_id // next, make a jump table for all possible values of the tag_id
let mut cases = Vec::with_capacity_in(tags.len(), env.arena); let mut cases = Vec::with_capacity_in(tags.len(), env.arena);
@ -1834,8 +1840,8 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>(
env.ptr_bytes, env.ptr_bytes,
); );
let wrapper_struct = debug_assert!(wrapper_type.is_struct_type());
cast_struct_struct(env.builder, wrapper_struct, wrapper_type.into_struct_type()); let wrapper_struct = cast_block_of_memory_to_tag(env.builder, wrapper_struct, wrapper_type);
for (i, field_layout) in field_layouts.iter().enumerate() { for (i, field_layout) in field_layouts.iter().enumerate() {
if let Layout::RecursivePointer = field_layout { if let Layout::RecursivePointer = field_layout {
@ -1849,12 +1855,14 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>(
// therefore we must cast it to our desired type // therefore we must cast it to our desired type
let union_type = block_of_memory(env.context, &layout, env.ptr_bytes); let union_type = block_of_memory(env.context, &layout, env.ptr_bytes);
let recursive_field_ptr = cast_basic_basic( let recursive_field_ptr = env
env.builder, .builder
ptr_as_i64_ptr, .build_bitcast(
union_type.ptr_type(AddressSpace::Generic).into(), ptr_as_i64_ptr,
) union_type.ptr_type(AddressSpace::Generic),
.into_pointer_value(); "recursive_to_desired",
)
.into_pointer_value();
let recursive_field = env let recursive_field = env
.builder .builder
@ -1889,8 +1897,7 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>(
env.builder.position_at_end(before_block); env.builder.position_at_end(before_block);
env.builder env.builder.build_switch(tag_id_u8, merge_block, &cases);
.build_switch(tag_id_u8.into_int_value(), merge_block, &cases);
env.builder.position_at_end(merge_block); env.builder.position_at_end(merge_block);

View file

@ -2074,4 +2074,46 @@ mod gen_primitives {
i64 i64
); );
} }
#[test]
fn bug_exposer() {
// the decision tree will generate a jump to the `1` branch here
assert_evals_to!(
indoc!(
r#"
app "test" provides [ main ] to "./platform"
Expr : [ ZAdd Expr Expr, Val I64, Var I64 ]
eval : Expr -> I64
eval = \e ->
when e is
Var _ -> 0
Val v -> v
ZAdd l r -> eval l + eval r
constFolding : Expr -> Expr
constFolding = \e ->
when e is
ZAdd e1 e2 ->
when Pair e1 e2 is
Pair (Val a) (Val b) -> Val (a+b)
Pair (Val a) (ZAdd x (Val b)) -> ZAdd (Val (a+b)) x
Pair _ _ -> ZAdd e1 e2
_ -> e
expr : Expr
expr = ZAdd (Val 3) (ZAdd (Val 4) (Val 5))
main : I64
main = eval (constFolding expr)
"#
),
12,
i64
);
}
} }

View file

@ -61,10 +61,10 @@ mod gen_str {
when List.first (Str.split "JJJ" "JJJJ there") is when List.first (Str.split "JJJ" "JJJJ there") is
Ok str -> Ok str ->
Str.countGraphemes str Str.countGraphemes str
_ -> _ ->
-1 -1
"# "#
), ),
3, 3,
@ -84,10 +84,10 @@ mod gen_str {
|> Str.concat str |> Str.concat str
|> Str.concat str |> Str.concat str
|> Str.concat str |> Str.concat str
_ -> _ ->
"Not Str!" "Not Str!"
"# "#
), ),
"JJJJJJJJJJJJJJJJJJJJJJJJJ", "JJJJJJJJJJJJJJJJJJJJJJJJJ",
@ -103,7 +103,7 @@ mod gen_str {
when when
List.first List.first
(Str.split "JJJ" "0123456789abcdefghi") (Str.split "JJJ" "0123456789abcdefghi")
is is
Ok str -> str Ok str -> str
_ -> "" _ -> ""
"# "#
@ -118,7 +118,7 @@ mod gen_str {
assert_evals_to!( assert_evals_to!(
indoc!( indoc!(
r#" r#"
Str.split "01234567789abcdefghi?01234567789abcdefghi" "?" Str.split "01234567789abcdefghi?01234567789abcdefghi" "?"
"# "#
), ),
&["01234567789abcdefghi", "01234567789abcdefghi"], &["01234567789abcdefghi", "01234567789abcdefghi"],
@ -128,7 +128,7 @@ mod gen_str {
assert_evals_to!( assert_evals_to!(
indoc!( indoc!(
r#" r#"
Str.split "01234567789abcdefghi 3ch 01234567789abcdefghi" "3ch" Str.split "01234567789abcdefghi 3ch 01234567789abcdefghi" "3ch"
"# "#
), ),
&["01234567789abcdefghi ", " 01234567789abcdefghi"], &["01234567789abcdefghi ", " 01234567789abcdefghi"],
@ -154,8 +154,8 @@ mod gen_str {
assert_evals_to!( assert_evals_to!(
indoc!( indoc!(
r#" r#"
Str.split Str.split
"string to split is shorter" "string to split is shorter"
"than the delimiter which happens to be very very long" "than the delimiter which happens to be very very long"
"# "#
), ),
@ -538,4 +538,34 @@ mod gen_str {
debug_assert_eq!(short.clone(), short); debug_assert_eq!(short.clone(), short);
debug_assert_eq!(empty.clone(), empty); debug_assert_eq!(empty.clone(), empty);
} }
#[test]
fn nested_recursive_literal() {
assert_evals_to!(
indoc!(
r#"
Expr : [ Add Expr Expr, Val I64, Var I64 ]
expr : Expr
expr = Add (Add (Val 3) (Val 1)) (Add (Val 1) (Var 1))
printExpr : Expr -> Str
printExpr = \e ->
when e is
Add a b ->
"Add ("
|> Str.concat (printExpr a)
|> Str.concat ") ("
|> Str.concat (printExpr b)
|> Str.concat ")"
Val v -> "Val " |> Str.concat (Str.fromInt v)
Var v -> "Var " |> Str.concat (Str.fromInt v)
printExpr expr
"#
),
"Add (Add (Val 3) (Val 1)) (Add (Val 1) (Var 1))",
&'static str
);
}
} }

View file

@ -941,4 +941,23 @@ mod gen_tags {
i64 i64
); );
} }
#[test]
fn nested_recursive_literal() {
assert_evals_to!(
indoc!(
r"#
Expr : [ Add Expr Expr, Val I64, Var I64 ]
e : Expr
e = Add (Add (Val 3) (Val 1)) (Add (Val 1) (Var 1))
e
#"
),
0,
&i64,
|x: &i64| *x
);
}
} }

View file

99
examples/task/CFold.roc Normal file
View file

@ -0,0 +1,99 @@
app "cfold"
packages { base: "thing/platform-dir" }
imports [base.Task]
provides [ main ] to base
# adapted from https://github.com/koka-lang/koka/blob/master/test/bench/haskell/cfold.hs
main : Task.Task {} []
main =
e = mkExpr 3 1
unoptimized = eval e
optimized = eval (constFolding (reassoc e))
unoptimized
|> Str.fromInt
|> Str.concat " & "
|> Str.concat (Str.fromInt optimized)
|> Task.putLine
Expr : [
Add Expr Expr,
Mul Expr Expr,
Val I64,
Var I64
]
mkExpr : I64, I64 -> Expr
mkExpr = \n , v ->
when n is
0 -> if v == 0 then Var 1 else Val v
_ -> Add (mkExpr (n-1) (v+1)) (mkExpr (n-1) (max (v-1) 0))
max : I64, I64 -> I64
max = \a, b -> if a > b then a else b
appendAdd : Expr, Expr -> Expr
appendAdd = \e1, e2 ->
when e1 is
Add a1 a2 -> Add a1 (appendAdd a2 e2)
_ -> Add e1 e2
appendMul : Expr, Expr -> Expr
appendMul = \e1, e2 ->
when e1 is
Mul a1 a2 -> Mul a1 (appendMul a2 e2)
_ -> Mul e1 e2
eval : Expr -> I64
eval = \e ->
when e is
Var _ -> 0
Val v -> v
Add l r -> eval l + eval r
Mul l r -> eval l * eval r
reassoc : Expr -> Expr
reassoc = \e ->
when e is
Add e1 e2 ->
x1 = reassoc e1
x2 = reassoc e2
appendAdd x1 x2
Mul e1 e2 ->
x1 = reassoc e1
x2 = reassoc e2
appendMul x1 x2
_ -> e
constFolding : Expr -> Expr
constFolding = \e ->
when e is
Add e1 e2 ->
x1 = constFolding e1
x2 = constFolding e2
when Pair x1 x2 is
Pair (Val a) (Val b) -> Val (a+b)
# Pair (Val a) (Add (Val b) x) -> Add (Val (a+b)) x
Pair (Val a) (Add x (Val b)) -> Add (Val (a+b)) x
Pair _ _ -> Add x1 x2
Mul e1 e2 ->
x1 = constFolding e1
x2 = constFolding e2
when Pair x1 x2 is
Pair (Val a) (Val b) -> Val (a*b)
Pair (Val a) (Mul (Val b) x) -> Mul (Val (a*b)) x
Pair (Val a) (Mul x (Val b)) -> Mul (Val (a*b)) x
Pair _ _ -> Mul x1 x2
_ -> e

54
examples/task/NQueens.roc Normal file
View file

@ -0,0 +1,54 @@
app "nqueens"
packages { base: "thing/platform-dir" }
imports [base.Task]
provides [ main ] to base
main : Task.Task {} []
main =
queens 10
|> Str.fromInt
|> Task.putLine
ConsList a : [ Nil, Cons a (ConsList a) ]
queens = \n -> length (findSolutions n n)
length : ConsList a -> I64
length = \xs -> lengthHelp xs 0
lengthHelp : ConsList a, I64 -> I64
lengthHelp = \xs, acc ->
when xs is
Nil -> acc
Cons _ rest -> lengthHelp rest (1 + acc)
safe : I64, I64, ConsList I64 -> Bool
safe = \queen, diagonal, xs ->
when xs is
Nil ->
True
Cons q t ->
queen != q && queen != q + diagonal && queen != q - diagonal && safe queen (diagonal + 1) t
appendSafe : I64, ConsList I64, ConsList (ConsList I64) -> ConsList (ConsList I64)
appendSafe = \k, soln, solns ->
if k <= 0 then
solns
else if safe k 1 soln then
appendSafe (k - 1) soln (Cons (Cons k soln) solns)
else
appendSafe (k - 1) soln solns
extend = \n, acc, solns ->
when solns is
Nil -> acc
Cons soln rest -> extend n (appendSafe n soln acc) rest
findSolutions = \n, k ->
if k == 0 then
Cons Nil Nil
else
extend n Nil (findSolutions n (k - 1))