diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index 0863d4ebdc..3dfcf48c5e 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -952,11 +952,7 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( let internal_type = basic_type_from_layout(env.arena, env.context, &tag_layout, env.ptr_bytes); - cast_basic_basic( - builder, - struct_val.into_struct_value().into(), - internal_type, - ) + cast_tag_to_block_of_memory(builder, struct_val.into_struct_value(), internal_type) } Tag { arguments, @@ -1001,10 +997,10 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( debug_assert!(val.is_pointer_value()); // we store recursive pointers as `i64*` - let ptr = cast_basic_basic( - builder, + let ptr = env.builder.build_bitcast( val, - ctx.i64_type().ptr_type(AddressSpace::Generic).into(), + ctx.i64_type().ptr_type(AddressSpace::Generic), + "cast_recursive_pointer", ); field_vals.push(ptr); @@ -1020,12 +1016,14 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( // Create the struct_type let data_ptr = reserve_with_refcount(env, &tag_layout); let struct_type = ctx.struct_type(field_types.into_bump_slice(), false); - let struct_ptr = cast_basic_basic( - builder, - data_ptr.into(), - struct_type.ptr_type(AddressSpace::Generic).into(), - ) - .into_pointer_value(); + let struct_ptr = env + .builder + .build_bitcast( + data_ptr, + struct_type.ptr_type(AddressSpace::Generic), + "block_of_memory_to_tag", + ) + .into_pointer_value(); // Insert field exprs into struct_val for (index, field_val) in field_vals.into_iter().enumerate() { @@ -1098,10 +1096,10 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( debug_assert!(val.is_pointer_value()); // we store recursive pointers as `i64*` - let ptr = cast_basic_basic( - builder, + let ptr = env.builder.build_bitcast( val, - ctx.i64_type().ptr_type(AddressSpace::Generic).into(), + ctx.i64_type().ptr_type(AddressSpace::Generic), + "cast_recursive_pointer", ); field_vals.push(ptr); @@ -1117,12 +1115,14 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( // Create the struct_type let data_ptr = reserve_with_refcount(env, &tag_layout); let struct_type = ctx.struct_type(field_types.into_bump_slice(), false); - let struct_ptr = cast_basic_basic( - builder, - data_ptr.into(), - struct_type.ptr_type(AddressSpace::Generic).into(), - ) - .into_pointer_value(); + let struct_ptr = env + .builder + .build_bitcast( + data_ptr, + struct_type.ptr_type(AddressSpace::Generic), + "block_of_memory_to_tag", + ) + .into_pointer_value(); // Insert field exprs into struct_val for (index, field_val) in field_vals.into_iter().enumerate() { @@ -1197,10 +1197,10 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( debug_assert!(val.is_pointer_value()); // we store recursive pointers as `i64*` - let ptr = cast_basic_basic( - builder, + let ptr = env.builder.build_bitcast( val, - ctx.i64_type().ptr_type(AddressSpace::Generic).into(), + ctx.i64_type().ptr_type(AddressSpace::Generic), + "cast_recursive_pointer", ); field_vals.push(ptr); @@ -1220,12 +1220,14 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( ); let struct_type = ctx.struct_type(field_types.into_bump_slice(), false); - let struct_ptr = cast_basic_basic( - builder, - data_ptr.into(), - struct_type.ptr_type(AddressSpace::Generic).into(), - ) - .into_pointer_value(); + let struct_ptr = env + .builder + .build_bitcast( + data_ptr, + struct_type.ptr_type(AddressSpace::Generic), + "block_of_memory_to_tag", + ) + .into_pointer_value(); // Insert field exprs into struct_val for (index, field_val) in field_vals.into_iter().enumerate() { @@ -1331,7 +1333,7 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( .context .struct_type(field_types.into_bump_slice(), false); - let struct_value = cast_struct_struct(builder, value, struct_type); + let struct_value = access_index_struct_value(builder, value, struct_type); let result = builder .build_extract_value(struct_value, *index as u32, "") @@ -1342,11 +1344,12 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( block_of_memory(env.context, &struct_layout, env.ptr_bytes); // the value is a pointer to the actual value; load that value! - let ptr = cast_basic_basic( - builder, + let ptr = env.builder.build_bitcast( result, - desired_type.ptr_type(AddressSpace::Generic).into(), + desired_type.ptr_type(AddressSpace::Generic), + "cast_struct_value_pointer", ); + builder.build_load(ptr.into_pointer_value(), "load_recursive_field") } else { result @@ -1494,12 +1497,14 @@ fn lookup_at_index_ptr<'a, 'ctx, 'env>( use inkwell::types::BasicType; let builder = env.builder; - let ptr = cast_basic_basic( - builder, - value.into(), - struct_type.ptr_type(AddressSpace::Generic).into(), - ) - .into_pointer_value(); + let ptr = env + .builder + .build_bitcast( + value, + struct_type.ptr_type(AddressSpace::Generic), + "cast_lookup_at_index_ptr", + ) + .into_pointer_value(); let elem_ptr = builder .build_struct_gep(ptr, index as u32, "at_index_struct_gep") @@ -1510,12 +1515,11 @@ fn lookup_at_index_ptr<'a, 'ctx, 'env>( if let Some(Layout::RecursivePointer) = field_layouts.get(index as usize) { // a recursive field is stored as a `i64*`, to use it we must cast it to // a pointer to the block of memory representation - cast_basic_basic( - builder, + builder.build_bitcast( result, block_of_memory(env.context, &struct_layout, env.ptr_bytes) - .ptr_type(AddressSpace::Generic) - .into(), + .ptr_type(AddressSpace::Generic), + "cast_rec_pointer_lookup_at_index_ptr", ) } else { result @@ -1583,12 +1587,13 @@ pub fn allocate_with_refcount_help<'a, 'ctx, 'env>( // We must return a pointer to the first element: let data_ptr = { let int_type = ptr_int(ctx, env.ptr_bytes); - let as_usize_ptr = cast_basic_basic( - env.builder, - ptr.into(), - int_type.ptr_type(AddressSpace::Generic).into(), - ) - .into_pointer_value(); + let as_usize_ptr = builder + .build_bitcast( + ptr, + int_type.ptr_type(AddressSpace::Generic), + "to_usize_ptr", + ) + .into_pointer_value(); let index = match extra_bytes { n if n == env.ptr_bytes => 1, @@ -1601,14 +1606,17 @@ pub fn allocate_with_refcount_help<'a, 'ctx, 'env>( let ptr_type = get_ptr_type(&value_type, AddressSpace::Generic); unsafe { - cast_basic_basic( - env.builder, - env.builder - .build_in_bounds_gep(as_usize_ptr, &[index_intvalue], "get_data_ptr") - .into(), - ptr_type.into(), - ) - .into_pointer_value() + builder + .build_bitcast( + env.builder.build_in_bounds_gep( + as_usize_ptr, + &[index_intvalue], + "get_data_ptr", + ), + ptr_type, + "malloc_cast_to_desired", + ) + .into_pointer_value() } }; @@ -1664,7 +1672,7 @@ fn list_literal<'a, 'ctx, 'env>( let ptr_bytes = env.ptr_bytes; let u8_ptr_type = ctx.i8_type().ptr_type(AddressSpace::Generic); - let generic_ptr = cast_basic_basic(builder, ptr.into(), u8_ptr_type.into()); + let generic_ptr = builder.build_bitcast(ptr, u8_ptr_type, "to_generic_ptr"); let struct_type = collection(ctx, ptr_bytes); let len = BasicValueEnum::IntValue(env.ptr_int().const_int(len_u64, false)); @@ -2104,14 +2112,18 @@ pub fn load_symbol_and_layout<'a, 'ctx, 'env, 'b>( None => panic!("There was no entry for {:?} in scope {:?}", symbol, scope), } } - -/// Cast a struct to another struct of the same (or smaller?) size -pub fn cast_struct_struct<'ctx>( +fn access_index_struct_value<'ctx>( builder: &Builder<'ctx>, from_value: StructValue<'ctx>, to_type: StructType<'ctx>, ) -> StructValue<'ctx> { - cast_basic_basic(builder, from_value.into(), to_type.into()).into_struct_value() + complex_bitcast( + builder, + from_value.into(), + to_type.into(), + "access_index_struct_value", + ) + .into_struct_value() } /// Cast a value to another value of the same (or smaller?) size @@ -2119,6 +2131,52 @@ pub fn cast_basic_basic<'ctx>( builder: &Builder<'ctx>, from_value: BasicValueEnum<'ctx>, to_type: BasicTypeEnum<'ctx>, +) -> BasicValueEnum<'ctx> { + complex_bitcast(builder, from_value, to_type, "cast_basic_basic") +} + +pub fn complex_bitcast_struct_struct<'ctx>( + builder: &Builder<'ctx>, + from_value: StructValue<'ctx>, + to_type: StructType<'ctx>, + name: &str, +) -> StructValue<'ctx> { + complex_bitcast(builder, from_value.into(), to_type.into(), name).into_struct_value() +} + +fn cast_tag_to_block_of_memory<'ctx>( + builder: &Builder<'ctx>, + from_value: StructValue<'ctx>, + to_type: BasicTypeEnum<'ctx>, +) -> BasicValueEnum<'ctx> { + complex_bitcast( + builder, + from_value.into(), + to_type, + "tag_to_block_of_memory", + ) +} + +pub fn cast_block_of_memory_to_tag<'ctx>( + builder: &Builder<'ctx>, + from_value: StructValue<'ctx>, + to_type: BasicTypeEnum<'ctx>, +) -> StructValue<'ctx> { + complex_bitcast( + builder, + from_value.into(), + to_type, + "block_of_memory_to_tag", + ) + .into_struct_value() +} + +/// Cast a value to another value of the same (or smaller?) size +pub fn complex_bitcast<'ctx>( + builder: &Builder<'ctx>, + from_value: BasicValueEnum<'ctx>, + to_type: BasicTypeEnum<'ctx>, + name: &str, ) -> BasicValueEnum<'ctx> { use inkwell::types::BasicType; @@ -2135,7 +2193,7 @@ pub fn cast_basic_basic<'ctx>( .build_bitcast( argument_pointer, to_type.ptr_type(inkwell::AddressSpace::Generic), - "cast_basic_basic", + name, ) .into_pointer_value(); @@ -2150,7 +2208,12 @@ fn extract_tag_discriminant_struct<'a, 'ctx, 'env>( .context .struct_type(&[env.context.i64_type().into()], false); - let struct_value = cast_struct_struct(env.builder, from_value, struct_type); + let struct_value = complex_bitcast_struct_struct( + env.builder, + from_value, + struct_type, + "extract_tag_discriminant_struct", + ); env.builder .build_extract_value(struct_value, 0, "") @@ -2219,6 +2282,8 @@ fn build_switch_ir<'a, 'ctx, 'env>( let scope = &mut copy; let cond_symbol = &cond_symbol; + let (cond_value, stored_layout) = load_symbol_and_layout(env, scope, cond_symbol); + debug_assert_eq!(&cond_layout, stored_layout); let cont_block = context.append_basic_block(parent, "cont"); @@ -2227,19 +2292,17 @@ fn build_switch_ir<'a, 'ctx, 'env>( Layout::Builtin(Builtin::Float64) => { // float matches are done on the bit pattern cond_layout = Layout::Builtin(Builtin::Int64); - let full_cond = load_symbol(env, scope, cond_symbol); builder - .build_bitcast(full_cond, env.context.i64_type(), "") + .build_bitcast(cond_value, env.context.i64_type(), "") .into_int_value() } Layout::Builtin(Builtin::Float32) => { // float matches are done on the bit pattern cond_layout = Layout::Builtin(Builtin::Int32); - let full_cond = load_symbol(env, scope, cond_symbol); builder - .build_bitcast(full_cond, env.context.i32_type(), "") + .build_bitcast(cond_value, env.context.i32_type(), "") .into_int_value() } Layout::Union(variant) => { @@ -2249,7 +2312,7 @@ fn build_switch_ir<'a, 'ctx, 'env>( NonRecursive(_) => { // we match on the discriminant, not the whole Tag cond_layout = Layout::Builtin(Builtin::Int64); - let full_cond = load_symbol(env, scope, cond_symbol).into_struct_value(); + let full_cond = cond_value.into_struct_value(); extract_tag_discriminant_struct(env, full_cond) } @@ -2257,21 +2320,13 @@ fn build_switch_ir<'a, 'ctx, 'env>( // we match on the discriminant, not the whole Tag cond_layout = Layout::Builtin(Builtin::Int64); - use BasicValueEnum::*; - match load_symbol(env, scope, cond_symbol) { - PointerValue(full_cond_ptr) => { - extract_tag_discriminant_ptr(env, full_cond_ptr) - } - StructValue(full_cond_struct) => { - extract_tag_discriminant_struct(env, full_cond_struct) - } - _ => unreachable!(), - } + debug_assert!(cond_value.is_pointer_value()); + extract_tag_discriminant_ptr(env, cond_value.into_pointer_value()) } NullableWrapped { nullable_id, .. } => { // we match on the discriminant, not the whole Tag cond_layout = Layout::Builtin(Builtin::Int64); - let full_cond_ptr = load_symbol(env, scope, cond_symbol).into_pointer_value(); + let full_cond_ptr = cond_value.into_pointer_value(); let comparison: IntValue = env.builder.build_is_null(full_cond_ptr, "is_null_cond"); @@ -2302,7 +2357,7 @@ fn build_switch_ir<'a, 'ctx, 'env>( } } } - Layout::Builtin(_) => load_symbol(env, scope, cond_symbol).into_int_value(), + Layout::Builtin(_) => cond_value.into_int_value(), other => todo!("Build switch value from layout: {:?}", other), }; diff --git a/compiler/gen/src/llvm/refcounting.rs b/compiler/gen/src/llvm/refcounting.rs index af15954af6..8aaa7f5314 100644 --- a/compiler/gen/src/llvm/refcounting.rs +++ b/compiler/gen/src/llvm/refcounting.rs @@ -1,5 +1,5 @@ use crate::llvm::build::{ - cast_basic_basic, cast_struct_struct, create_entry_block_alloca, set_name, Env, Scope, + cast_basic_basic, cast_block_of_memory_to_tag, create_entry_block_alloca, set_name, Env, Scope, FAST_CALL_CONV, LLVM_SADD_WITH_OVERFLOW_I64, }; use crate::llvm::build_list::{incrementing_elem_loop, list_len, load_list}; @@ -45,12 +45,14 @@ impl<'ctx> PointerToRefcount<'ctx> { // must make sure it's a pointer to usize let refcount_type = ptr_int(env.context, env.ptr_bytes); - let value = cast_basic_basic( - env.builder, - ptr.into(), - refcount_type.ptr_type(AddressSpace::Generic).into(), - ) - .into_pointer_value(); + let value = env + .builder + .build_bitcast( + ptr, + refcount_type.ptr_type(AddressSpace::Generic), + "to_refcount_ptr", + ) + .into_pointer_value(); Self { value } } @@ -64,7 +66,8 @@ impl<'ctx> PointerToRefcount<'ctx> { let refcount_type = ptr_int(env.context, env.ptr_bytes); let refcount_ptr_type = refcount_type.ptr_type(AddressSpace::Generic); - let ptr_as_usize_ptr = cast_basic_basic(builder, data_ptr.into(), refcount_ptr_type.into()) + let ptr_as_usize_ptr = builder + .build_bitcast(data_ptr, refcount_ptr_type, "as_usize_ptr") .into_pointer_value(); // get a pointer to index -1 @@ -1232,12 +1235,14 @@ pub fn build_dec_rec_union_help<'a, 'ctx, 'env>( ); // cast the opaque pointer to a pointer of the correct shape - let struct_ptr = cast_basic_basic( - env.builder, - value_ptr.into(), - wrapper_type.ptr_type(AddressSpace::Generic).into(), - ) - .into_pointer_value(); + let struct_ptr = env + .builder + .build_bitcast( + value_ptr, + wrapper_type.ptr_type(AddressSpace::Generic), + "opaque_to_correct", + ) + .into_pointer_value(); for (i, field_layout) in field_layouts.iter().enumerate() { if let Layout::RecursivePointer = field_layout { @@ -1428,8 +1433,8 @@ pub fn build_dec_union_help<'a, 'ctx, 'env>( env.ptr_bytes, ); - let wrapper_struct = - cast_struct_struct(env.builder, wrapper_struct, wrapper_type.into_struct_type()); + debug_assert!(wrapper_type.is_struct_type()); + let wrapper_struct = cast_block_of_memory_to_tag(env.builder, wrapper_struct, wrapper_type); for (i, field_layout) in field_layouts.iter().enumerate() { if let Layout::RecursivePointer = field_layout { @@ -1528,15 +1533,11 @@ fn rec_union_read_tag<'a, 'ctx, 'env>( ) -> IntValue<'ctx> { // Assumption: the tag is the first thing stored // so cast the pointer to the data to a `i64*` - let tag_ptr = cast_basic_basic( - env.builder, - value_ptr.into(), - env.context - .i64_type() - .ptr_type(AddressSpace::Generic) - .into(), - ) - .into_pointer_value(); + let tag_ptr_type = env.context.i64_type().ptr_type(AddressSpace::Generic); + let tag_ptr = env + .builder + .build_bitcast(value_ptr, tag_ptr_type, "cast_tag_ptr") + .into_pointer_value(); env.builder .build_load(tag_ptr, "load_tag_id") @@ -1634,12 +1635,14 @@ pub fn build_inc_rec_union_help<'a, 'ctx, 'env>( ); // cast the opaque pointer to a pointer of the correct shape - let struct_ptr = cast_basic_basic( - env.builder, - value_ptr.into(), - wrapper_type.ptr_type(AddressSpace::Generic).into(), - ) - .into_pointer_value(); + let struct_ptr = env + .builder + .build_bitcast( + value_ptr, + wrapper_type.ptr_type(AddressSpace::Generic), + "opaque_to_correct", + ) + .into_pointer_value(); for (i, field_layout) in field_layouts.iter().enumerate() { if let Layout::RecursivePointer = field_layout { @@ -1657,10 +1660,10 @@ pub fn build_inc_rec_union_help<'a, 'ctx, 'env>( // therefore we must cast it to our desired type let union_type = block_of_memory_slices(env.context, tags, env.ptr_bytes); - let recursive_field_ptr = cast_basic_basic( - env.builder, + let recursive_field_ptr = env.builder.build_bitcast( ptr_as_i64_ptr, - union_type.ptr_type(AddressSpace::Generic).into(), + union_type.ptr_type(AddressSpace::Generic), + "recursive_to_desired", ); // recursively increment the field @@ -1694,10 +1697,11 @@ pub fn build_inc_rec_union_help<'a, 'ctx, 'env>( // read the tag_id let tag_id = rec_union_read_tag(env, value_ptr); - let tag_id_u8 = cast_basic_basic(env.builder, tag_id.into(), env.context.i8_type().into()); + let tag_id_u8 = env + .builder + .build_int_cast(tag_id, env.context.i8_type(), "tag_id_u8"); - env.builder - .build_switch(tag_id_u8.into_int_value(), merge_block, &cases); + env.builder.build_switch(tag_id_u8, merge_block, &cases); env.builder.position_at_end(merge_block); @@ -1808,7 +1812,9 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>( .into_int_value() }; - let tag_id_u8 = cast_basic_basic(env.builder, tag_id.into(), env.context.i8_type().into()); + let tag_id_u8 = env + .builder + .build_int_cast(tag_id, env.context.i8_type(), "tag_id_u8"); // next, make a jump table for all possible values of the tag_id let mut cases = Vec::with_capacity_in(tags.len(), env.arena); @@ -1834,8 +1840,8 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>( env.ptr_bytes, ); - let wrapper_struct = - cast_struct_struct(env.builder, wrapper_struct, wrapper_type.into_struct_type()); + debug_assert!(wrapper_type.is_struct_type()); + let wrapper_struct = cast_block_of_memory_to_tag(env.builder, wrapper_struct, wrapper_type); for (i, field_layout) in field_layouts.iter().enumerate() { if let Layout::RecursivePointer = field_layout { @@ -1849,12 +1855,14 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>( // therefore we must cast it to our desired type let union_type = block_of_memory(env.context, &layout, env.ptr_bytes); - let recursive_field_ptr = cast_basic_basic( - env.builder, - ptr_as_i64_ptr, - union_type.ptr_type(AddressSpace::Generic).into(), - ) - .into_pointer_value(); + let recursive_field_ptr = env + .builder + .build_bitcast( + ptr_as_i64_ptr, + union_type.ptr_type(AddressSpace::Generic), + "recursive_to_desired", + ) + .into_pointer_value(); let recursive_field = env .builder @@ -1889,8 +1897,7 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>( env.builder.position_at_end(before_block); - env.builder - .build_switch(tag_id_u8.into_int_value(), merge_block, &cases); + env.builder.build_switch(tag_id_u8, merge_block, &cases); env.builder.position_at_end(merge_block); diff --git a/compiler/gen/tests/gen_primitives.rs b/compiler/gen/tests/gen_primitives.rs index f9f02e0104..92ee4ff158 100644 --- a/compiler/gen/tests/gen_primitives.rs +++ b/compiler/gen/tests/gen_primitives.rs @@ -2074,4 +2074,46 @@ mod gen_primitives { i64 ); } + + #[test] + fn bug_exposer() { + // the decision tree will generate a jump to the `1` branch here + assert_evals_to!( + indoc!( + r#" + app "test" provides [ main ] to "./platform" + + Expr : [ ZAdd Expr Expr, Val I64, Var I64 ] + + eval : Expr -> I64 + eval = \e -> + when e is + Var _ -> 0 + Val v -> v + ZAdd l r -> eval l + eval r + + constFolding : Expr -> Expr + constFolding = \e -> + when e is + ZAdd e1 e2 -> + when Pair e1 e2 is + Pair (Val a) (Val b) -> Val (a+b) + Pair (Val a) (ZAdd x (Val b)) -> ZAdd (Val (a+b)) x + Pair _ _ -> ZAdd e1 e2 + + + _ -> e + + + expr : Expr + expr = ZAdd (Val 3) (ZAdd (Val 4) (Val 5)) + + main : I64 + main = eval (constFolding expr) + "# + ), + 12, + i64 + ); + } } diff --git a/compiler/gen/tests/gen_str.rs b/compiler/gen/tests/gen_str.rs index 0f74d1fc02..fc224ae798 100644 --- a/compiler/gen/tests/gen_str.rs +++ b/compiler/gen/tests/gen_str.rs @@ -61,10 +61,10 @@ mod gen_str { when List.first (Str.split "JJJ" "JJJJ there") is Ok str -> Str.countGraphemes str - + _ -> -1 - + "# ), 3, @@ -84,10 +84,10 @@ mod gen_str { |> Str.concat str |> Str.concat str |> Str.concat str - + _ -> "Not Str!" - + "# ), "JJJJJJJJJJJJJJJJJJJJJJJJJ", @@ -103,7 +103,7 @@ mod gen_str { when List.first (Str.split "JJJ" "0123456789abcdefghi") - is + is Ok str -> str _ -> "" "# @@ -118,7 +118,7 @@ mod gen_str { assert_evals_to!( indoc!( r#" - Str.split "01234567789abcdefghi?01234567789abcdefghi" "?" + Str.split "01234567789abcdefghi?01234567789abcdefghi" "?" "# ), &["01234567789abcdefghi", "01234567789abcdefghi"], @@ -128,7 +128,7 @@ mod gen_str { assert_evals_to!( indoc!( r#" - Str.split "01234567789abcdefghi 3ch 01234567789abcdefghi" "3ch" + Str.split "01234567789abcdefghi 3ch 01234567789abcdefghi" "3ch" "# ), &["01234567789abcdefghi ", " 01234567789abcdefghi"], @@ -154,8 +154,8 @@ mod gen_str { assert_evals_to!( indoc!( r#" - Str.split - "string to split is shorter" + Str.split + "string to split is shorter" "than the delimiter which happens to be very very long" "# ), @@ -538,4 +538,34 @@ mod gen_str { debug_assert_eq!(short.clone(), short); debug_assert_eq!(empty.clone(), empty); } + + #[test] + fn nested_recursive_literal() { + assert_evals_to!( + indoc!( + r#" + Expr : [ Add Expr Expr, Val I64, Var I64 ] + + expr : Expr + expr = Add (Add (Val 3) (Val 1)) (Add (Val 1) (Var 1)) + + printExpr : Expr -> Str + printExpr = \e -> + when e is + Add a b -> + "Add (" + |> Str.concat (printExpr a) + |> Str.concat ") (" + |> Str.concat (printExpr b) + |> Str.concat ")" + Val v -> "Val " |> Str.concat (Str.fromInt v) + Var v -> "Var " |> Str.concat (Str.fromInt v) + + printExpr expr + "# + ), + "Add (Add (Val 3) (Val 1)) (Add (Val 1) (Var 1))", + &'static str + ); + } } diff --git a/compiler/gen/tests/gen_tags.rs b/compiler/gen/tests/gen_tags.rs index bf65dad6a8..6f7d0674d7 100644 --- a/compiler/gen/tests/gen_tags.rs +++ b/compiler/gen/tests/gen_tags.rs @@ -941,4 +941,23 @@ mod gen_tags { i64 ); } + + #[test] + fn nested_recursive_literal() { + assert_evals_to!( + indoc!( + r"# + Expr : [ Add Expr Expr, Val I64, Var I64 ] + + e : Expr + e = Add (Add (Val 3) (Val 1)) (Add (Val 1) (Var 1)) + + e + #" + ), + 0, + &i64, + |x: &i64| *x + ); + } } diff --git a/examples/effect/Deriv.roc b/examples/effect/Deriv.roc new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/task/CFold.roc b/examples/task/CFold.roc new file mode 100644 index 0000000000..cd8cfd1187 --- /dev/null +++ b/examples/task/CFold.roc @@ -0,0 +1,99 @@ +app "cfold" + packages { base: "thing/platform-dir" } + imports [base.Task] + provides [ main ] to base + +# adapted from https://github.com/koka-lang/koka/blob/master/test/bench/haskell/cfold.hs + +main : Task.Task {} [] +main = + e = mkExpr 3 1 + unoptimized = eval e + optimized = eval (constFolding (reassoc e)) + + unoptimized + |> Str.fromInt + |> Str.concat " & " + |> Str.concat (Str.fromInt optimized) + |> Task.putLine + +Expr : [ + Add Expr Expr, + Mul Expr Expr, + Val I64, + Var I64 + ] + +mkExpr : I64, I64 -> Expr +mkExpr = \n , v -> + when n is + 0 -> if v == 0 then Var 1 else Val v + _ -> Add (mkExpr (n-1) (v+1)) (mkExpr (n-1) (max (v-1) 0)) + +max : I64, I64 -> I64 +max = \a, b -> if a > b then a else b + + +appendAdd : Expr, Expr -> Expr +appendAdd = \e1, e2 -> + when e1 is + Add a1 a2 -> Add a1 (appendAdd a2 e2) + _ -> Add e1 e2 + +appendMul : Expr, Expr -> Expr +appendMul = \e1, e2 -> + when e1 is + Mul a1 a2 -> Mul a1 (appendMul a2 e2) + _ -> Mul e1 e2 + + +eval : Expr -> I64 +eval = \e -> + when e is + Var _ -> 0 + Val v -> v + Add l r -> eval l + eval r + Mul l r -> eval l * eval r + +reassoc : Expr -> Expr +reassoc = \e -> + when e is + Add e1 e2 -> + x1 = reassoc e1 + x2 = reassoc e2 + + appendAdd x1 x2 + + Mul e1 e2 -> + x1 = reassoc e1 + x2 = reassoc e2 + + appendMul x1 x2 + + _ -> e + +constFolding : Expr -> Expr +constFolding = \e -> + when e is + Add e1 e2 -> + x1 = constFolding e1 + x2 = constFolding e2 + + when Pair x1 x2 is + Pair (Val a) (Val b) -> Val (a+b) + # Pair (Val a) (Add (Val b) x) -> Add (Val (a+b)) x + Pair (Val a) (Add x (Val b)) -> Add (Val (a+b)) x + Pair _ _ -> Add x1 x2 + + Mul e1 e2 -> + x1 = constFolding e1 + x2 = constFolding e2 + + when Pair x1 x2 is + Pair (Val a) (Val b) -> Val (a*b) + Pair (Val a) (Mul (Val b) x) -> Mul (Val (a*b)) x + Pair (Val a) (Mul x (Val b)) -> Mul (Val (a*b)) x + Pair _ _ -> Mul x1 x2 + + _ -> e + diff --git a/examples/task/NQueens.roc b/examples/task/NQueens.roc new file mode 100644 index 0000000000..dce1ca3dea --- /dev/null +++ b/examples/task/NQueens.roc @@ -0,0 +1,54 @@ +app "nqueens" + packages { base: "thing/platform-dir" } + imports [base.Task] + provides [ main ] to base + +main : Task.Task {} [] +main = + queens 10 + |> Str.fromInt + |> Task.putLine + +ConsList a : [ Nil, Cons a (ConsList a) ] + +queens = \n -> length (findSolutions n n) + + +length : ConsList a -> I64 +length = \xs -> lengthHelp xs 0 + +lengthHelp : ConsList a, I64 -> I64 +lengthHelp = \xs, acc -> + when xs is + Nil -> acc + Cons _ rest -> lengthHelp rest (1 + acc) + +safe : I64, I64, ConsList I64 -> Bool +safe = \queen, diagonal, xs -> + when xs is + Nil -> + True + + Cons q t -> + queen != q && queen != q + diagonal && queen != q - diagonal && safe queen (diagonal + 1) t + +appendSafe : I64, ConsList I64, ConsList (ConsList I64) -> ConsList (ConsList I64) +appendSafe = \k, soln, solns -> + if k <= 0 then + solns + else if safe k 1 soln then + appendSafe (k - 1) soln (Cons (Cons k soln) solns) + else + appendSafe (k - 1) soln solns + +extend = \n, acc, solns -> + when solns is + Nil -> acc + Cons soln rest -> extend n (appendSafe n soln acc) rest + +findSolutions = \n, k -> + if k == 0 then + Cons Nil Nil + + else + extend n Nil (findSolutions n (k - 1))