diff --git a/Cargo.lock b/Cargo.lock index d2f1756c6c..6cee71c44b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1081,7 +1081,18 @@ checksum = "fc587bc0ec293155d5bfa6b9891ec18a1e330c234f896ea47fbada4cadbe47e6" dependencies = [ "cfg-if 0.1.10", "libc", - "wasi", + "wasi 0.9.0+wasi-snapshot-preview1", +] + +[[package]] +name = "getrandom" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4060f4657be78b8e766215b02b18a2e862d83745545de804638e2b545e81aee6" +dependencies = [ + "cfg-if 1.0.0", + "libc", + "wasi 0.10.1+wasi-snapshot-preview1", ] [[package]] @@ -2267,13 +2278,25 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" dependencies = [ - "getrandom", + "getrandom 0.1.15", "libc", "rand_chacha 0.2.2", "rand_core 0.5.1", "rand_hc 0.2.0", ] +[[package]] +name = "rand" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18519b42a40024d661e1714153e9ad0c3de27cd495760ceb09710920f1098b1e" +dependencies = [ + "libc", + "rand_chacha 0.3.0", + "rand_core 0.6.1", + "rand_hc 0.3.0", +] + [[package]] name = "rand_chacha" version = "0.1.1" @@ -2294,6 +2317,16 @@ dependencies = [ "rand_core 0.5.1", ] +[[package]] +name = "rand_chacha" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e12735cf05c9e10bf21534da50a147b924d555dc7a547c42e6bb2d5b6017ae0d" +dependencies = [ + "ppv-lite86", + "rand_core 0.6.1", +] + [[package]] name = "rand_core" version = "0.3.1" @@ -2315,7 +2348,16 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" dependencies = [ - "getrandom", + "getrandom 0.1.15", +] + +[[package]] +name = "rand_core" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c026d7df8b298d90ccbbc5190bd04d85e159eaf5576caeacf8741da93ccbd2e5" +dependencies = [ + "getrandom 0.2.1", ] [[package]] @@ -2336,6 +2378,15 @@ dependencies = [ "rand_core 0.5.1", ] +[[package]] +name = "rand_hc" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3190ef7066a446f2e7f42e239d161e905420ccab01eb967c9eb27d21b2322a73" +dependencies = [ + "rand_core 0.6.1", +] + [[package]] name = "rand_isaac" version = "0.1.1" @@ -2459,7 +2510,7 @@ version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "de0737333e7a9502c789a36d7c7fa6092a49895d4faa31ca5df163857ded2e9d" dependencies = [ - "getrandom", + "getrandom 0.1.15", "redox_syscall", "rust-argon2", ] @@ -2678,6 +2729,7 @@ dependencies = [ "pretty_assertions", "quickcheck", "quickcheck_macros", + "rand 0.8.2", "roc_builtins", "roc_can", "roc_collections", @@ -3727,6 +3779,12 @@ version = "0.9.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" +[[package]] +name = "wasi" +version = "0.10.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93c6c3420963c5c64bca373b25e77acb562081b9bb4dd5bb864187742186cea9" + [[package]] name = "wasm-bindgen" version = "0.2.69" diff --git a/Cargo.toml b/Cargo.toml index 777e51cf01..2d986f0009 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,6 @@ members = [ [profile.release] lto = "fat" codegen-units = 1 -debug = true # enable when profiling +# debug = true # enable when profiling diff --git a/cli/src/repl/eval.rs b/cli/src/repl/eval.rs index f96d4686ce..dfc23a74aa 100644 --- a/cli/src/repl/eval.rs +++ b/cli/src/repl/eval.rs @@ -224,6 +224,7 @@ fn jit_to_ast_help<'a>( Layout::Union(UnionLayout::Recursive(_)) | Layout::Union(UnionLayout::NullableWrapped { .. }) | Layout::Union(UnionLayout::NullableUnwrapped { .. }) + | Layout::Union(UnionLayout::NonNullableUnwrapped(_)) | Layout::RecursivePointer => { todo!("add support for rendering recursive tag unions in the REPL") } @@ -305,6 +306,9 @@ fn ptr_to_ast<'a>( let (tag_name, payload_vars) = tags.iter().next().unwrap(); single_tag_union_to_ast(env, ptr, field_layouts, tag_name.clone(), payload_vars) } + Content::Structure(FlatType::EmptyRecord) => { + struct_to_ast(env, ptr, &[], &MutMap::default()) + } other => { unreachable!( "Something had a Struct layout, but instead of a Record type, it had: {:?}", diff --git a/cli/tests/cli_run.rs b/cli/tests/cli_run.rs index c62608b414..b2d90626ff 100644 --- a/cli/tests/cli_run.rs +++ b/cli/tests/cli_run.rs @@ -178,7 +178,7 @@ mod cli_run { &example_file("benchmarks", "NQueens.roc"), "nqueens", &[], - "724\n", + "4\n", false, ); } @@ -207,6 +207,30 @@ mod cli_run { ); } + #[test] + #[serial(deriv)] + fn run_rbtree_insert_not_optimized() { + check_output( + &example_file("benchmarks", "RBTreeInsert.roc"), + "rbtree-insert", + &[], + "Node Black 0 {} Empty Empty\n", + false, + ); + } + + #[test] + #[serial(deriv)] + fn run_rbtree_delete_not_optimized() { + check_output( + &example_file("benchmarks", "RBTreeDel.roc"), + "rbtree-del", + &[], + "30\n", + false, + ); + } + // #[test] // #[serial(effect)] // fn run_effect_unoptimized() { diff --git a/cli/tests/repl_eval.rs b/cli/tests/repl_eval.rs index 57a1dfc4d4..fa36b83885 100644 --- a/cli/tests/repl_eval.rs +++ b/cli/tests/repl_eval.rs @@ -199,6 +199,11 @@ mod repl_eval { expect_success("[]", "[] : List *"); } + #[test] + fn literal_empty_list_empty_record() { + expect_success("[ {} ]", "[ {} ] : List {}"); + } + #[test] fn literal_num_list() { expect_success("[ 1, 2, 3 ]", "[ 1, 2, 3 ] : List (Num *)"); diff --git a/compiler/builtins/bitcode/src/str.zig b/compiler/builtins/bitcode/src/str.zig index 9a9f07db1d..7bc214f28d 100644 --- a/compiler/builtins/bitcode/src/str.zig +++ b/compiler/builtins/bitcode/src/str.zig @@ -831,20 +831,14 @@ fn strConcatHelp(allocator: *Allocator, comptime T: type, result_in_place: InPla var result = RocStr.initBig(allocator, T, result_in_place, combined_length); { - const old_if_small = &@bitCast([16]u8, arg1); - const old_if_big = @ptrCast([*]u8, arg1.str_bytes); - const old_bytes = if (arg1.isSmallStr()) old_if_small else old_if_big; - - const new_bytes: [*]u8 = @ptrCast([*]u8, result.str_bytes); + const old_bytes = arg1.asU8ptr(); + const new_bytes = @ptrCast([*]u8, result.str_bytes); @memcpy(new_bytes, old_bytes, arg1.len()); } { - const old_if_small = &@bitCast([16]u8, arg2); - const old_if_big = @ptrCast([*]u8, arg2.str_bytes); - const old_bytes = if (arg2.isSmallStr()) old_if_small else old_if_big; - + const old_bytes = arg2.asU8ptr(); const new_bytes = @ptrCast([*]u8, result.str_bytes) + arg1.len(); @memcpy(new_bytes, old_bytes, arg2.len()); diff --git a/compiler/can/src/def.rs b/compiler/can/src/def.rs index aeefb303fe..82c9fb71c8 100644 --- a/compiler/can/src/def.rs +++ b/compiler/can/src/def.rs @@ -789,8 +789,19 @@ fn canonicalize_pending_def<'a>( let arity = typ.arity(); + let problem = match &loc_can_pattern.value { + Pattern::Identifier(symbol) => RuntimeError::NoImplementationNamed { + def_symbol: *symbol, + }, + Pattern::Shadowed(region, loc_ident) => RuntimeError::Shadowing { + original_region: *region, + shadow: loc_ident.clone(), + }, + _ => RuntimeError::NoImplementation, + }; + // Fabricate a body for this annotation, that will error at runtime - let value = Expr::RuntimeError(RuntimeError::NoImplementation); + let value = Expr::RuntimeError(problem); let is_closure = arity > 0; let loc_can_expr = if !is_closure { Located { diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index 3d79900341..c7c3b31c59 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -805,8 +805,11 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( Tag { union_size, arguments, + tag_layout, .. - } if *union_size == 1 => { + } if *union_size == 1 + && matches!(tag_layout, Layout::Union(UnionLayout::NonRecursive(_))) => + { let it = arguments.iter(); let ctx = env.context; @@ -1037,6 +1040,83 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( data_ptr.into() } + Tag { + arguments, + tag_layout: Layout::Union(UnionLayout::NonNullableUnwrapped(fields)), + union_size, + tag_id, + .. + } => { + debug_assert_eq!(*union_size, 1); + debug_assert_eq!(*tag_id, 0); + debug_assert_eq!(arguments.len(), fields.len()); + + let struct_layout = + Layout::Union(UnionLayout::NonRecursive(env.arena.alloc([*fields]))); + + let ptr_size = env.ptr_bytes; + let ctx = env.context; + let builder = env.builder; + + // Determine types + let num_fields = arguments.len() + 1; + let mut field_types = Vec::with_capacity_in(num_fields, env.arena); + let mut field_vals = Vec::with_capacity_in(num_fields, env.arena); + + for (field_symbol, tag_field_layout) in arguments.iter().zip(fields.iter()) { + let val = load_symbol(env, scope, field_symbol); + + // Zero-sized fields have no runtime representation. + // The layout of the struct expects them to be dropped! + if !tag_field_layout.is_dropped_because_empty() { + let field_type = + basic_type_from_layout(env.arena, env.context, tag_field_layout, ptr_size); + + field_types.push(field_type); + + if let Layout::RecursivePointer = tag_field_layout { + debug_assert!(val.is_pointer_value()); + + // we store recursive pointers as `i64*` + let ptr = env.builder.build_bitcast( + val, + ctx.i64_type().ptr_type(AddressSpace::Generic), + "cast_recursive_pointer", + ); + + field_vals.push(ptr); + } else { + // this check fails for recursive tag unions, but can be helpful while debugging + + field_vals.push(val); + } + } + } + + // Create the struct_type + let data_ptr = reserve_with_refcount(env, &struct_layout); + let struct_type = ctx.struct_type(field_types.into_bump_slice(), false); + let struct_ptr = env + .builder + .build_bitcast( + data_ptr, + struct_type.ptr_type(AddressSpace::Generic), + "block_of_memory_to_tag", + ) + .into_pointer_value(); + + // Insert field exprs into struct_val + for (index, field_val) in field_vals.into_iter().enumerate() { + let field_ptr = builder + .build_struct_gep(struct_ptr, index as u32, "struct_gep") + .unwrap(); + + builder.build_store(field_ptr, field_val); + } + + data_ptr.into() + } + Tag { arguments, tag_layout: @@ -1250,24 +1330,12 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( index, structure, wrapped: Wrapped::SingleElementRecord, + field_layouts, .. } => { - match load_symbol_and_layout(env, scope, structure) { - (StructValue(argument), Layout::Struct(fields)) if fields.len() > 1 => - // TODO so sometimes a value gets Wrapped::SingleElementRecord - // but still has multiple fields... - { - env.builder - .build_extract_value( - argument, - *index as u32, - env.arena - .alloc(format!("struct_field_access_single_element{}", index)), - ) - .unwrap() - } - (other, _) => other, - } + debug_assert_eq!(field_layouts.len(), 1); + debug_assert_eq!(*index, 0); + load_symbol(env, scope, structure) } AccessAtIndex { @@ -1278,15 +1346,17 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( } => { // extract field from a record match load_symbol_and_layout(env, scope, structure) { - (StructValue(argument), Layout::Struct(fields)) if fields.len() > 1 => env - .builder - .build_extract_value( - argument, - *index as u32, - env.arena - .alloc(format!("struct_field_access_record_{}", index)), - ) - .unwrap(), + (StructValue(argument), Layout::Struct(fields)) => { + debug_assert!(fields.len() > 1); + env.builder + .build_extract_value( + argument, + *index as u32, + env.arena + .alloc(format!("struct_field_access_record_{}", index)), + ) + .unwrap() + } (StructValue(argument), Layout::Closure(_, _, _)) => env .builder .build_extract_value( @@ -1295,6 +1365,38 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( env.arena.alloc(format!("closure_field_access_{}_", index)), ) .unwrap(), + ( + PointerValue(argument), + Layout::Union(UnionLayout::NonNullableUnwrapped(fields)), + ) => { + let struct_layout = Layout::Struct(fields); + let struct_type = basic_type_from_layout( + env.arena, + env.context, + &struct_layout, + env.ptr_bytes, + ); + + let cast_argument = env + .builder + .build_bitcast( + argument, + struct_type.ptr_type(AddressSpace::Generic), + "cast_rosetree_like", + ) + .into_pointer_value(); + + let ptr = env + .builder + .build_struct_gep( + cast_argument, + *index as u32, + env.arena.alloc(format!("non_nullable_unwrapped_{}", index)), + ) + .unwrap(); + + env.builder.build_load(ptr, "load_rosetree_like") + } (other, layout) => unreachable!( "can only index into struct layout\nValue: {:?}\nLayout: {:?}\nIndex: {:?}", other, layout, index @@ -1826,6 +1928,7 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>( } else { basic_type_from_layout(env.arena, context, &layout, env.ptr_bytes) }; + let alloca = create_entry_block_alloca( env, parent, @@ -2038,12 +2141,12 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>( // This doesn't currently do anything context.i64_type().const_zero().into() } - Inc(symbol, cont) => { + Inc(symbol, inc_amount, cont) => { let (value, layout) = load_symbol_and_layout(env, scope, symbol); let layout = layout.clone(); if layout.contains_refcounted() { - increment_refcount_layout(env, parent, layout_ids, value, &layout); + increment_refcount_layout(env, parent, layout_ids, *inc_amount, value, &layout); } build_exp_stmt(env, layout_ids, scope, parent, cont) @@ -2176,24 +2279,34 @@ pub fn complex_bitcast<'ctx>( ) -> BasicValueEnum<'ctx> { use inkwell::types::BasicType; - // we can't use the more simple // builder.build_bitcast(from_value, to_type, "cast_basic_basic") // because this does not allow some (valid) bitcasts - // store the value in memory - let argument_pointer = builder.build_alloca(from_value.get_type(), "cast_alloca"); - builder.build_store(argument_pointer, from_value); + use BasicTypeEnum::*; + match (from_value.get_type(), to_type) { + (PointerType(_), PointerType(_)) => { + // we can't use the more straightforward bitcast in all cases + // it seems like a bitcast only works on integers and pointers + // and crucially does not work not on arrays + builder.build_bitcast(from_value, to_type, name) + } + _ => { + // store the value in memory + let argument_pointer = builder.build_alloca(from_value.get_type(), "cast_alloca"); + builder.build_store(argument_pointer, from_value); - // then read it back as a different type - let to_type_pointer = builder - .build_bitcast( - argument_pointer, - to_type.ptr_type(inkwell::AddressSpace::Generic), - name, - ) - .into_pointer_value(); + // then read it back as a different type + let to_type_pointer = builder + .build_bitcast( + argument_pointer, + to_type.ptr_type(inkwell::AddressSpace::Generic), + name, + ) + .into_pointer_value(); - builder.build_load(to_type_pointer, "cast_value") + builder.build_load(to_type_pointer, "cast_value") + } + } } fn extract_tag_discriminant_struct<'a, 'ctx, 'env>( @@ -2319,6 +2432,7 @@ fn build_switch_ir<'a, 'ctx, 'env>( debug_assert!(cond_value.is_pointer_value()); extract_tag_discriminant_ptr(env, cond_value.into_pointer_value()) } + NonNullableUnwrapped(_) => unreachable!("there is no tag to switch on"), NullableWrapped { nullable_id, .. } => { // we match on the discriminant, not the whole Tag cond_layout = Layout::Builtin(Builtin::Int64); diff --git a/compiler/gen/src/llvm/build_list.rs b/compiler/gen/src/llvm/build_list.rs index 0600170390..480dc34705 100644 --- a/compiler/gen/src/llvm/build_list.rs +++ b/compiler/gen/src/llvm/build_list.rs @@ -576,7 +576,7 @@ pub fn list_get_unsafe<'a, 'ctx, 'env>( let result = builder.build_load(elem_ptr, "List.get"); - increment_refcount_layout(env, parent, layout_ids, result, elem_layout); + increment_refcount_layout(env, parent, layout_ids, 1, result, elem_layout); result } @@ -1369,11 +1369,11 @@ macro_rules! list_map_help { let list_ptr = load_list_ptr(builder, list_wrapper, ptr_type); let list_loop = |index, before_elem| { - increment_refcount_layout($env, parent, layout_ids, before_elem, elem_layout); + increment_refcount_layout($env, parent, layout_ids, 1, before_elem, elem_layout); let arguments = match closure_info { Some((closure_data_layout, closure_data)) => { - increment_refcount_layout( $env, parent, layout_ids, closure_data, closure_data_layout); + increment_refcount_layout( $env, parent, layout_ids, 1, closure_data, closure_data_layout); bumpalo::vec![in $env.arena; before_elem, closure_data] } diff --git a/compiler/gen/src/llvm/convert.rs b/compiler/gen/src/llvm/convert.rs index c4229d97d2..eb6686ef4f 100644 --- a/compiler/gen/src/llvm/convert.rs +++ b/compiler/gen/src/llvm/convert.rs @@ -146,6 +146,10 @@ pub fn basic_type_from_layout<'ctx>( let block = block_of_memory_slices(context, &[&other_fields[1..]], ptr_bytes); block.ptr_type(AddressSpace::Generic).into() } + NonNullableUnwrapped(fields) => { + let block = block_of_memory_slices(context, &[fields], ptr_bytes); + block.ptr_type(AddressSpace::Generic).into() + } NonRecursive(_) => block_of_memory(context, layout, ptr_bytes), } } diff --git a/compiler/gen/src/llvm/refcounting.rs b/compiler/gen/src/llvm/refcounting.rs index 8aaa7f5314..7ddbfd3e9f 100644 --- a/compiler/gen/src/llvm/refcounting.rs +++ b/compiler/gen/src/llvm/refcounting.rs @@ -1,6 +1,6 @@ use crate::llvm::build::{ - cast_basic_basic, cast_block_of_memory_to_tag, create_entry_block_alloca, set_name, Env, Scope, - FAST_CALL_CONV, LLVM_SADD_WITH_OVERFLOW_I64, + cast_basic_basic, cast_block_of_memory_to_tag, set_name, Env, FAST_CALL_CONV, + LLVM_SADD_WITH_OVERFLOW_I64, }; use crate::llvm::build_list::{incrementing_elem_loop, list_len, load_list}; use crate::llvm::convert::{ @@ -101,7 +101,19 @@ impl<'ctx> PointerToRefcount<'ctx> { env.builder.build_store(self.value, refcount); } - fn increment<'a, 'env>(&self, env: &Env<'a, 'ctx, 'env>) { + fn modify<'a, 'env>( + &self, + mode: CallMode<'ctx>, + layout: &Layout<'a>, + env: &Env<'a, 'ctx, 'env>, + ) { + match mode { + CallMode::Inc(_, inc_amount) => self.increment(inc_amount, env), + CallMode::Dec => self.decrement(env, layout), + } + } + + fn increment<'a, 'env>(&self, amount: IntValue<'ctx>, env: &Env<'a, 'ctx, 'env>) { let refcount = self.get_refcount(env); let builder = env.builder; let refcount_type = ptr_int(env.context, env.ptr_bytes); @@ -112,11 +124,7 @@ impl<'ctx> PointerToRefcount<'ctx> { refcount_type.const_int(REFCOUNT_MAX as u64, false), "refcount_max_check", ); - let incremented = builder.build_int_add( - refcount, - refcount_type.const_int(1_u64, false), - "increment_refcount", - ); + let incremented = builder.build_int_add(refcount, amount, "increment_refcount"); let new_refcount = builder .build_select(max, refcount, incremented, "select_refcount") @@ -164,6 +172,7 @@ impl<'ctx> PointerToRefcount<'ctx> { env.builder.position_at_end(block); env.builder .set_current_debug_location(env.context, di_location); + let call = env .builder .build_call(function, &[refcount_ptr.into()], fn_name); @@ -284,12 +293,13 @@ impl<'ctx> PointerToRefcount<'ctx> { } } -pub fn decrement_refcount_struct<'a, 'ctx, 'env>( +fn modify_refcount_struct<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, parent: FunctionValue<'ctx>, layout_ids: &mut LayoutIds<'a>, value: BasicValueEnum<'ctx>, layouts: &[Layout<'a>], + mode: Mode, ) { let wrapper_struct = value.into_struct_value(); @@ -300,11 +310,29 @@ pub fn decrement_refcount_struct<'a, 'ctx, 'env>( .build_extract_value(wrapper_struct, i as u32, "decrement_struct_field") .unwrap(); - decrement_refcount_layout(env, parent, layout_ids, field_ptr, field_layout) + modify_refcount_layout(env, parent, layout_ids, mode, field_ptr, field_layout); } } } +pub fn increment_refcount_layout<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + parent: FunctionValue<'ctx>, + layout_ids: &mut LayoutIds<'a>, + inc_amount: u64, + value: BasicValueEnum<'ctx>, + layout: &Layout<'a>, +) { + modify_refcount_layout( + env, + parent, + layout_ids, + Mode::Inc(inc_amount), + value, + layout, + ); +} + pub fn decrement_refcount_layout<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, parent: FunctionValue<'ctx>, @@ -312,83 +340,14 @@ pub fn decrement_refcount_layout<'a, 'ctx, 'env>( value: BasicValueEnum<'ctx>, layout: &Layout<'a>, ) { - use Layout::*; - - match layout { - Builtin(builtin) => { - decrement_refcount_builtin(env, parent, layout_ids, value, layout, builtin) - } - Closure(_, closure_layout, _) => { - if closure_layout.contains_refcounted() { - let wrapper_struct = value.into_struct_value(); - - let field_ptr = env - .builder - .build_extract_value(wrapper_struct, 1, "decrement_closure_data") - .unwrap(); - - decrement_refcount_layout( - env, - parent, - layout_ids, - field_ptr, - &closure_layout.as_block_of_memory_layout(), - ) - } - } - PhantomEmptyStruct => {} - - Struct(layouts) => { - decrement_refcount_struct(env, parent, layout_ids, value, layouts); - } - RecursivePointer => todo!("TODO implement decrement layout of recursive tag union"), - - Union(variant) => { - use UnionLayout::*; - - match variant { - NonRecursive(tags) => { - build_dec_union(env, layout_ids, tags, value); - } - - NullableWrapped { - other_tags: tags, .. - } => { - debug_assert!(value.is_pointer_value()); - - build_dec_rec_union(env, layout_ids, tags, value.into_pointer_value(), true); - } - - NullableUnwrapped { other_fields, .. } => { - debug_assert!(value.is_pointer_value()); - - let other_fields = &other_fields[1..]; - - build_dec_rec_union( - env, - layout_ids, - &*env.arena.alloc([other_fields]), - value.into_pointer_value(), - true, - ); - } - - Recursive(tags) => { - debug_assert!(value.is_pointer_value()); - build_dec_rec_union(env, layout_ids, tags, value.into_pointer_value(), false); - } - } - } - - FunctionPointer(_, _) | Pointer(_) => {} - } + modify_refcount_layout(env, parent, layout_ids, Mode::Dec, value, layout); } -#[inline(always)] -fn decrement_refcount_builtin<'a, 'ctx, 'env>( +fn modify_refcount_builtin<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, parent: FunctionValue<'ctx>, layout_ids: &mut LayoutIds<'a>, + mode: Mode, value: BasicValueEnum<'ctx>, layout: &Layout<'a>, builtin: &Builtin<'a>, @@ -406,7 +365,7 @@ fn decrement_refcount_builtin<'a, 'ctx, 'env>( let (len, ptr) = load_list(env.builder, wrapper_struct, ptr_type); let loop_fn = |_index, element| { - decrement_refcount_layout(env, parent, layout_ids, element, element_layout); + modify_refcount_layout(env, parent, layout_ids, mode, element, element_layout); }; incrementing_elem_loop( @@ -415,13 +374,13 @@ fn decrement_refcount_builtin<'a, 'ctx, 'env>( parent, ptr, len, - "dec_index", + "modify_rc_index", loop_fn, ); } if let MemoryMode::Refcounted = memory_mode { - build_dec_list(env, layout_ids, layout, wrapper_struct); + modify_refcount_list(env, layout_ids, mode, layout, wrapper_struct); } } Set(element_layout) => { @@ -439,15 +398,19 @@ fn decrement_refcount_builtin<'a, 'ctx, 'env>( } Str => { let wrapper_struct = value.into_struct_value(); - build_dec_str(env, layout_ids, layout, wrapper_struct); + modify_refcount_str(env, layout_ids, mode, layout, wrapper_struct); + } + _ => { + debug_assert!(!builtin.is_refcounted()); } - _ => {} } } -pub fn increment_refcount_layout<'a, 'ctx, 'env>( + +fn modify_refcount_layout<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, parent: FunctionValue<'ctx>, layout_ids: &mut LayoutIds<'a>, + mode: Mode, value: BasicValueEnum<'ctx>, layout: &Layout<'a>, ) { @@ -455,7 +418,7 @@ pub fn increment_refcount_layout<'a, 'ctx, 'env>( match layout { Builtin(builtin) => { - increment_refcount_builtin(env, parent, layout_ids, value, layout, builtin) + modify_refcount_builtin(env, parent, layout_ids, mode, value, layout, builtin) } Union(variant) => { @@ -467,7 +430,14 @@ pub fn increment_refcount_layout<'a, 'ctx, 'env>( } => { debug_assert!(value.is_pointer_value()); - build_inc_rec_union(env, layout_ids, tags, value.into_pointer_value(), true); + build_rec_union( + env, + layout_ids, + mode, + tags, + value.into_pointer_value(), + true, + ); } NullableUnwrapped { other_fields, .. } => { @@ -475,23 +445,42 @@ pub fn increment_refcount_layout<'a, 'ctx, 'env>( let other_fields = &other_fields[1..]; - build_inc_rec_union( + build_rec_union( env, layout_ids, + mode, &*env.arena.alloc([other_fields]), value.into_pointer_value(), true, ); } + NonNullableUnwrapped(fields) => { + debug_assert!(value.is_pointer_value()); + + build_rec_union( + env, + layout_ids, + mode, + &*env.arena.alloc([*fields]), + value.into_pointer_value(), + true, + ); + } + Recursive(tags) => { debug_assert!(value.is_pointer_value()); - build_inc_rec_union(env, layout_ids, tags, value.into_pointer_value(), false); + build_rec_union( + env, + layout_ids, + mode, + tags, + value.into_pointer_value(), + false, + ); } - NonRecursive(tags) => { - build_inc_union(env, layout_ids, tags, value); - } + NonRecursive(tags) => modify_refcount_union(env, layout_ids, mode, tags, value), } } Closure(_, closure_layout, _) => { @@ -503,90 +492,44 @@ pub fn increment_refcount_layout<'a, 'ctx, 'env>( .build_extract_value(wrapper_struct, 1, "increment_closure_data") .unwrap(); - increment_refcount_layout( + modify_refcount_layout( env, parent, layout_ids, + mode, field_ptr, &closure_layout.as_block_of_memory_layout(), ) } } - _ => {} + + Struct(layouts) => { + modify_refcount_struct(env, parent, layout_ids, value, layouts, mode); + } + + PhantomEmptyStruct => {} + + RecursivePointer => todo!("TODO implement decrement layout of recursive tag union"), + + FunctionPointer(_, _) | Pointer(_) => {} } } -#[inline(always)] -fn increment_refcount_builtin<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - parent: FunctionValue<'ctx>, - layout_ids: &mut LayoutIds<'a>, - value: BasicValueEnum<'ctx>, - layout: &Layout<'a>, - builtin: &Builtin<'a>, -) { - use Builtin::*; - - match builtin { - List(memory_mode, element_layout) => { - let wrapper_struct = value.into_struct_value(); - if element_layout.contains_refcounted() { - let ptr_type = - basic_type_from_layout(env.arena, env.context, element_layout, env.ptr_bytes) - .ptr_type(AddressSpace::Generic); - - let (len, ptr) = load_list(env.builder, wrapper_struct, ptr_type); - - let loop_fn = |_index, element| { - increment_refcount_layout(env, parent, layout_ids, element, element_layout); - }; - - incrementing_elem_loop( - env.builder, - env.context, - parent, - ptr, - len, - "inc_index", - loop_fn, - ); - } - - if let MemoryMode::Refcounted = memory_mode { - build_inc_list(env, layout_ids, layout, wrapper_struct); - } - } - Set(element_layout) => { - if element_layout.contains_refcounted() { - // TODO decrement all values - } - todo!(); - } - Dict(key_layout, value_layout) => { - if key_layout.contains_refcounted() || value_layout.contains_refcounted() { - // TODO decrement all values - } - - todo!(); - } - Str => { - let wrapper_struct = value.into_struct_value(); - build_inc_str(env, layout_ids, layout, wrapper_struct); - } - _ => {} - } -} - -pub fn build_inc_list<'a, 'ctx, 'env>( +fn modify_refcount_list<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, + mode: Mode, layout: &Layout<'a>, original_wrapper: StructValue<'ctx>, ) { let block = env.builder.get_insert_block().expect("to be in a function"); let di_location = env.builder.get_current_debug_location().unwrap(); - let symbol = Symbol::INC; + let (call_name, symbol) = match mode { + Mode::Inc(_) => ("increment_list", Symbol::INC), + Mode::Dec => ("decrement_list", Symbol::DEC), + }; + let fn_name = layout_ids .get(symbol, &layout) .to_symbol_string(symbol, &env.interns); @@ -595,9 +538,9 @@ pub fn build_inc_list<'a, 'ctx, 'env>( Some(function_value) => function_value, None => { let basic_type = basic_type_from_layout(env.arena, env.context, &layout, env.ptr_bytes); - let function_value = build_header(env, basic_type, &fn_name); + let function_value = build_header(env, basic_type, mode, &fn_name); - build_inc_list_help(env, layout_ids, layout, function_value); + modify_refcount_list_help(env, mode, layout, function_value); function_value } @@ -606,16 +549,20 @@ pub fn build_inc_list<'a, 'ctx, 'env>( env.builder.position_at_end(block); env.builder .set_current_debug_location(env.context, di_location); - let call = env - .builder - .build_call(function, &[original_wrapper.into()], "increment_list"); - call.set_call_convention(FAST_CALL_CONV); + call_help(env, function, mode, original_wrapper.into(), call_name); } -fn build_inc_list_help<'a, 'ctx, 'env>( +fn mode_to_call_mode(function: FunctionValue<'_>, mode: Mode) -> CallMode<'_> { + match mode { + Mode::Dec => CallMode::Dec, + Mode::Inc(num) => CallMode::Inc(num, function.get_nth_param(1).unwrap().into_int_value()), + } +} + +fn modify_refcount_list_help<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, - _layout_ids: &mut LayoutIds<'a>, + mode: Mode, layout: &Layout<'a>, fn_val: FunctionValue<'ctx>, ) { @@ -644,25 +591,12 @@ fn build_inc_list_help<'a, 'ctx, 'env>( ); builder.set_current_debug_location(&ctx, loc); - let mut scope = Scope::default(); - // Add args to scope let arg_symbol = Symbol::ARG_1; let arg_val = fn_val.get_param_iter().next().unwrap(); set_name(arg_val, arg_symbol.ident_string(&env.interns)); - let alloca = create_entry_block_alloca( - env, - fn_val, - arg_val.get_type(), - arg_symbol.ident_string(&env.interns), - ); - - builder.build_store(alloca, arg_val); - - scope.insert(arg_symbol, (layout.clone(), alloca)); - let parent = fn_val; let original_wrapper = arg_val.into_struct_value(); @@ -676,15 +610,16 @@ fn build_inc_list_help<'a, 'ctx, 'env>( ); // build blocks - let increment_block = ctx.append_basic_block(parent, "increment_block"); - let cont_block = ctx.append_basic_block(parent, "after_increment_block"); + let modification_block = ctx.append_basic_block(parent, "modification_block"); + let cont_block = ctx.append_basic_block(parent, "modify_rc_list_cont"); - builder.build_conditional_branch(is_non_empty, increment_block, cont_block); + builder.build_conditional_branch(is_non_empty, modification_block, cont_block); - builder.position_at_end(increment_block); + builder.position_at_end(modification_block); let refcount_ptr = PointerToRefcount::from_list_wrapper(env, original_wrapper); - refcount_ptr.increment(env); + let call_mode = mode_to_call_mode(fn_val, mode); + refcount_ptr.modify(call_mode, layout, env); builder.build_unconditional_branch(cont_block); @@ -694,16 +629,21 @@ fn build_inc_list_help<'a, 'ctx, 'env>( builder.build_return(None); } -pub fn build_dec_list<'a, 'ctx, 'env>( +fn modify_refcount_str<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, + mode: Mode, layout: &Layout<'a>, original_wrapper: StructValue<'ctx>, ) { let block = env.builder.get_insert_block().expect("to be in a function"); let di_location = env.builder.get_current_debug_location().unwrap(); - let symbol = Symbol::DEC; + let (call_name, symbol) = match mode { + Mode::Inc(_) => ("increment_str", Symbol::INC), + Mode::Dec => ("decrement_str", Symbol::DEC), + }; + let fn_name = layout_ids .get(symbol, &layout) .to_symbol_string(symbol, &env.interns); @@ -712,9 +652,9 @@ pub fn build_dec_list<'a, 'ctx, 'env>( Some(function_value) => function_value, None => { let basic_type = basic_type_from_layout(env.arena, env.context, &layout, env.ptr_bytes); - let function_value = build_header(env, basic_type, &fn_name); + let function_value = build_header(env, basic_type, mode, &fn_name); - build_dec_list_help(env, layout_ids, layout, function_value); + modify_refcount_str_help(env, mode, layout, function_value); function_value } @@ -723,15 +663,13 @@ pub fn build_dec_list<'a, 'ctx, 'env>( env.builder.position_at_end(block); env.builder .set_current_debug_location(env.context, di_location); - let call = env - .builder - .build_call(function, &[original_wrapper.into()], "decrement_list"); - call.set_call_convention(FAST_CALL_CONV); + + call_help(env, function, mode, original_wrapper.into(), call_name); } -fn build_dec_list_help<'a, 'ctx, 'env>( +fn modify_refcount_str_help<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, - _layout_ids: &mut LayoutIds<'a>, + mode: Mode, layout: &Layout<'a>, fn_val: FunctionValue<'ctx>, ) { @@ -760,145 +698,12 @@ fn build_dec_list_help<'a, 'ctx, 'env>( ); builder.set_current_debug_location(&ctx, loc); - let mut scope = Scope::default(); - // Add args to scope let arg_symbol = Symbol::ARG_1; let arg_val = fn_val.get_param_iter().next().unwrap(); set_name(arg_val, arg_symbol.ident_string(&env.interns)); - let alloca = create_entry_block_alloca( - env, - fn_val, - arg_val.get_type(), - arg_symbol.ident_string(&env.interns), - ); - - builder.build_store(alloca, arg_val); - - scope.insert(arg_symbol, (layout.clone(), alloca)); - - let parent = fn_val; - - // the block we'll always jump to when we're done - let cont_block = ctx.append_basic_block(parent, "after_decrement_block_build_dec_list_help"); - let decrement_block = ctx.append_basic_block(parent, "decrement_block"); - - // currently, an empty list has a null-pointer in its length is 0 - // so we must first check the length - - let original_wrapper = arg_val.into_struct_value(); - - let len = list_len(builder, original_wrapper); - let is_non_empty = builder.build_int_compare( - IntPredicate::UGT, - len, - ctx.i64_type().const_zero(), - "len > 0", - ); - - // if the length is 0, we're done and jump to the continuation block - // otherwise, actually read and check the refcount - builder.build_conditional_branch(is_non_empty, decrement_block, cont_block); - builder.position_at_end(decrement_block); - - let refcount_ptr = PointerToRefcount::from_list_wrapper(env, original_wrapper); - refcount_ptr.decrement(env, layout); - - env.builder.build_unconditional_branch(cont_block); - - builder.position_at_end(cont_block); - - // this function returns void - builder.build_return(None); -} - -pub fn build_inc_str<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - layout_ids: &mut LayoutIds<'a>, - layout: &Layout<'a>, - original_wrapper: StructValue<'ctx>, -) { - let block = env.builder.get_insert_block().expect("to be in a function"); - let di_location = env.builder.get_current_debug_location().unwrap(); - - let symbol = Symbol::INC; - let fn_name = layout_ids - .get(symbol, &layout) - .to_symbol_string(symbol, &env.interns); - - let function = match env.module.get_function(fn_name.as_str()) { - Some(function_value) => function_value, - None => { - let basic_type = basic_type_from_layout(env.arena, env.context, &layout, env.ptr_bytes); - let function_value = build_header(env, basic_type, &fn_name); - - build_inc_str_help(env, layout_ids, layout, function_value); - - function_value - } - }; - - env.builder.position_at_end(block); - env.builder - .set_current_debug_location(env.context, di_location); - let call = env - .builder - .build_call(function, &[original_wrapper.into()], "increment_str"); - call.set_call_convention(FAST_CALL_CONV); -} - -fn build_inc_str_help<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - _layout_ids: &mut LayoutIds<'a>, - layout: &Layout<'a>, - fn_val: FunctionValue<'ctx>, -) { - let builder = env.builder; - let ctx = env.context; - - // Add a basic block for the entry point - let entry = ctx.append_basic_block(fn_val, "entry"); - - builder.position_at_end(entry); - - let func_scope = fn_val.get_subprogram().unwrap(); - let lexical_block = env.dibuilder.create_lexical_block( - /* scope */ func_scope.as_debug_info_scope(), - /* file */ env.compile_unit.get_file(), - /* line_no */ 0, - /* column_no */ 0, - ); - - let loc = env.dibuilder.create_debug_location( - ctx, - /* line */ 0, - /* column */ 0, - /* current_scope */ lexical_block.as_debug_info_scope(), - /* inlined_at */ None, - ); - builder.set_current_debug_location(&ctx, loc); - - let mut scope = Scope::default(); - - // Add args to scope - let arg_symbol = Symbol::ARG_1; - let arg_val = fn_val.get_param_iter().next().unwrap(); - - set_name(arg_val, arg_symbol.ident_string(&env.interns)); - - let alloca = create_entry_block_alloca( - env, - fn_val, - arg_val.get_type(), - arg_symbol.ident_string(&env.interns), - ); - - builder.build_store(alloca, arg_val); - - scope.insert(arg_symbol, (layout.clone(), alloca)); - let parent = fn_val; let str_wrapper = arg_val.into_struct_value(); @@ -913,138 +718,19 @@ fn build_inc_str_help<'a, 'ctx, 'env>( IntPredicate::SGT, len, ptr_int(ctx, env.ptr_bytes).const_zero(), - "len > 0", + "is_big_str", ); // the block we'll always jump to when we're done - let cont_block = ctx.append_basic_block(parent, "after_increment_block"); - let decrement_block = ctx.append_basic_block(parent, "increment_block"); + let cont_block = ctx.append_basic_block(parent, "modify_rc_str_cont"); + let modification_block = ctx.append_basic_block(parent, "modify_rc"); - builder.build_conditional_branch(is_big_and_non_empty, decrement_block, cont_block); - builder.position_at_end(decrement_block); + builder.build_conditional_branch(is_big_and_non_empty, modification_block, cont_block); + builder.position_at_end(modification_block); let refcount_ptr = PointerToRefcount::from_list_wrapper(env, str_wrapper); - refcount_ptr.increment(env); - - builder.build_unconditional_branch(cont_block); - - builder.position_at_end(cont_block); - - // this function returns void - builder.build_return(None); -} - -pub fn build_dec_str<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - layout_ids: &mut LayoutIds<'a>, - layout: &Layout<'a>, - original_wrapper: StructValue<'ctx>, -) { - let block = env.builder.get_insert_block().expect("to be in a function"); - let di_location = env.builder.get_current_debug_location().unwrap(); - - let symbol = Symbol::DEC; - let fn_name = layout_ids - .get(symbol, &layout) - .to_symbol_string(symbol, &env.interns); - - let function = match env.module.get_function(fn_name.as_str()) { - Some(function_value) => function_value, - None => { - let basic_type = basic_type_from_layout(env.arena, env.context, &layout, env.ptr_bytes); - let function_value = build_header(env, basic_type, &fn_name); - - build_dec_str_help(env, layout_ids, layout, function_value); - - function_value - } - }; - - env.builder.position_at_end(block); - env.builder - .set_current_debug_location(env.context, di_location); - let call = env - .builder - .build_call(function, &[original_wrapper.into()], "decrement_str"); - call.set_call_convention(FAST_CALL_CONV); -} - -fn build_dec_str_help<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - _layout_ids: &mut LayoutIds<'a>, - layout: &Layout<'a>, - fn_val: FunctionValue<'ctx>, -) { - let builder = env.builder; - let ctx = env.context; - - // Add a basic block for the entry point - let entry = ctx.append_basic_block(fn_val, "entry"); - - builder.position_at_end(entry); - - let func_scope = fn_val.get_subprogram().unwrap(); - let lexical_block = env.dibuilder.create_lexical_block( - /* scope */ func_scope.as_debug_info_scope(), - /* file */ env.compile_unit.get_file(), - /* line_no */ 0, - /* column_no */ 0, - ); - - let loc = env.dibuilder.create_debug_location( - ctx, - /* line */ 0, - /* column */ 0, - /* current_scope */ lexical_block.as_debug_info_scope(), - /* inlined_at */ None, - ); - builder.set_current_debug_location(&ctx, loc); - - let mut scope = Scope::default(); - - // Add args to scope - let arg_symbol = Symbol::ARG_1; - let arg_val = fn_val.get_param_iter().next().unwrap(); - - set_name(arg_val, arg_symbol.ident_string(&env.interns)); - - let alloca = create_entry_block_alloca( - env, - fn_val, - arg_val.get_type(), - arg_symbol.ident_string(&env.interns), - ); - - builder.build_store(alloca, arg_val); - - scope.insert(arg_symbol, (layout.clone(), alloca)); - - let parent = fn_val; - - let str_wrapper = arg_val.into_struct_value(); - let len = builder - .build_extract_value(str_wrapper, Builtin::WRAPPER_LEN, "read_str_ptr") - .unwrap() - .into_int_value(); - - // Small strings have 1 as the first bit of length, making them negative. - // Thus, to check for big and non empty, just needs a signed len > 0. - let is_big_and_non_empty = builder.build_int_compare( - IntPredicate::SGT, - len, - ptr_int(ctx, env.ptr_bytes).const_zero(), - "len > 0", - ); - - // the block we'll always jump to when we're done - let cont_block = ctx.append_basic_block(parent, "after_decrement_block_build_dec_str_help"); - let decrement_block = ctx.append_basic_block(parent, "decrement_block"); - - builder.build_conditional_branch(is_big_and_non_empty, decrement_block, cont_block); - builder.position_at_end(decrement_block); - - let refcount_ptr = PointerToRefcount::from_list_wrapper(env, str_wrapper); - refcount_ptr.decrement(env, layout); + let call_mode = mode_to_call_mode(fn_val, mode); + refcount_ptr.modify(call_mode, layout, env); builder.build_unconditional_branch(cont_block); @@ -1055,12 +741,21 @@ fn build_dec_str_help<'a, 'ctx, 'env>( } /// Build an increment or decrement function for a specific layout -pub fn build_header<'a, 'ctx, 'env>( +fn build_header<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, arg_type: BasicTypeEnum<'ctx>, + mode: Mode, fn_name: &str, ) -> FunctionValue<'ctx> { - build_header_help(env, fn_name, env.context.void_type().into(), &[arg_type]) + match mode { + Mode::Inc(_) => build_header_help( + env, + fn_name, + env.context.void_type().into(), + &[arg_type, ptr_int(env.context, env.ptr_bytes).into()], + ), + Mode::Dec => build_header_help(env, fn_name, env.context.void_type().into(), &[arg_type]), + } } /// Build an increment or decrement function for a specific layout @@ -1097,19 +792,33 @@ pub fn build_header_help<'a, 'ctx, 'env>( fn_val } -pub fn build_dec_rec_union<'a, 'ctx, 'env>( +#[derive(Clone, Copy)] +enum Mode { + Inc(u64), + Dec, +} + +#[derive(Clone, Copy)] +enum CallMode<'ctx> { + Inc(u64, IntValue<'ctx>), + Dec, +} + +fn build_rec_union<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, + mode: Mode, fields: &'a [&'a [Layout<'a>]], value: PointerValue<'ctx>, is_nullable: bool, ) { let layout = Layout::Union(UnionLayout::Recursive(fields)); - let block = env.builder.get_insert_block().expect("to be in a function"); - let di_location = env.builder.get_current_debug_location().unwrap(); + let (call_name, symbol) = match mode { + Mode::Inc(_) => ("increment_rec_union", Symbol::INC), + Mode::Dec => ("decrement_rec_union", Symbol::DEC), + }; - let symbol = Symbol::DEC; let fn_name = layout_ids .get(symbol, &layout) .to_symbol_string(symbol, &env.interns); @@ -1117,31 +826,31 @@ pub fn build_dec_rec_union<'a, 'ctx, 'env>( let function = match env.module.get_function(fn_name.as_str()) { Some(function_value) => function_value, None => { + let block = env.builder.get_insert_block().expect("to be in a function"); + let di_location = env.builder.get_current_debug_location().unwrap(); + let basic_type = block_of_memory_slices(env.context, fields, env.ptr_bytes) .ptr_type(AddressSpace::Generic) .into(); - let function_value = build_header(env, basic_type, &fn_name); + let function_value = build_header(env, basic_type, mode, &fn_name); - build_dec_rec_union_help(env, layout_ids, fields, function_value, is_nullable); + build_rec_union_help(env, layout_ids, mode, fields, function_value, is_nullable); + + env.builder.position_at_end(block); + env.builder + .set_current_debug_location(env.context, di_location); function_value } }; - env.builder.position_at_end(block); - env.builder - .set_current_debug_location(env.context, di_location); - - let call = env - .builder - .build_call(function, &[value.into()], "decrement_union"); - - call.set_call_convention(FAST_CALL_CONV); + call_help(env, function, mode, value.into(), call_name); } -pub fn build_dec_rec_union_help<'a, 'ctx, 'env>( +fn build_rec_union_help<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, + mode: Mode, tags: &[&[Layout<'a>]], fn_val: FunctionValue<'ctx>, is_nullable: bool, @@ -1151,6 +860,8 @@ pub fn build_dec_rec_union_help<'a, 'ctx, 'env>( let context = &env.context; let builder = env.builder; + let pick = |a, b| if let Mode::Inc(_) = mode { a } else { b }; + // Add a basic block for the entry point let entry = context.append_basic_block(fn_val, "entry"); @@ -1187,6 +898,25 @@ pub fn build_dec_rec_union_help<'a, 'ctx, 'env>( debug_assert!(arg_val.is_pointer_value()); let value_ptr = arg_val.into_pointer_value(); + // branches that are not/don't contain anything refcounted + // if there is only one branch, we don't need to switch + let switch_needed: bool = (|| { + for field_layouts in tags.iter() { + // if none of the fields are or contain anything refcounted, just move on + if !field_layouts + .iter() + .any(|x| x.is_refcounted() || x.contains_refcounted()) + { + return true; + } + } + false + })(); + + // to increment/decrement the cons-cell itself + let refcount_ptr = PointerToRefcount::from_ptr_to_data(env, value_ptr); + let call_mode = mode_to_call_mode(fn_val, mode); + let ctx = env.context; let cont_block = ctx.append_basic_block(parent, "cont"); if is_nullable { @@ -1211,8 +941,6 @@ pub fn build_dec_rec_union_help<'a, 'ctx, 'env>( // next, make a jump table for all possible values of the tag_id let mut cases = Vec::with_capacity_in(tags.len(), env.arena); - let merge_block = env.context.append_basic_block(parent, "decrement_merge"); - builder.set_current_debug_location(&context, loc); for (tag_id, field_layouts) in tags.iter().enumerate() { @@ -1224,7 +952,10 @@ pub fn build_dec_rec_union_help<'a, 'ctx, 'env>( continue; } - let block = env.context.append_basic_block(parent, "tag_id_decrement"); + let block = env + .context + .append_basic_block(parent, pick("tag_id_increment", "tag_id_decrement")); + env.builder.position_at_end(block); let wrapper_type = basic_type_from_layout( @@ -1244,6 +975,11 @@ pub fn build_dec_rec_union_help<'a, 'ctx, 'env>( ) .into_pointer_value(); + // defer actually performing the refcount modifications until after the current cell has + // been decremented, see below + let mut deferred_rec = Vec::new_in(env.arena); + let mut deferred_nonrec = Vec::new_in(env.arena); + for (i, field_layout) in field_layouts.iter().enumerate() { if let Layout::RecursivePointer = field_layout { // this field has type `*i64`, but is really a pointer to the data we want @@ -1266,33 +1002,43 @@ pub fn build_dec_rec_union_help<'a, 'ctx, 'env>( union_type.ptr_type(AddressSpace::Generic).into(), ); - // recursively decrement the field - let call = env.builder.build_call( - fn_val, - &[recursive_field_ptr], - "recursive_tag_decrement", - ); - - // Because it's an internal-only function, use the fast calling convention. - call.set_call_convention(FAST_CALL_CONV); + deferred_rec.push(recursive_field_ptr); } else if field_layout.contains_refcounted() { - // TODO this loads the whole field onto the stack; - // that's wasteful if e.g. the field is a big record, where only - // some fields are actually refcounted. let elem_pointer = env .builder .build_struct_gep(struct_ptr, i as u32, "gep_recursive_pointer") .unwrap(); - let field = env - .builder - .build_load(elem_pointer, "decrement_struct_field"); + let field = env.builder.build_load( + elem_pointer, + pick("increment_struct_field", "decrement_struct_field"), + ); - decrement_refcount_layout(env, parent, layout_ids, field, field_layout); + deferred_nonrec.push((field, field_layout)); } } - env.builder.build_unconditional_branch(merge_block); + // OPTIMIZATION + // + // We really would like `inc/dec` to be tail-recursive; it gives roughly a 2X speedup on linked + // lists. To achieve it, we must first load all fields that we want to inc/dec (done above) + // and store them on the stack, then modify (and potentially free) the current cell, then + // actually inc/dec the fields. + refcount_ptr.modify(call_mode, &layout, env); + + for (field, field_layout) in deferred_nonrec { + modify_refcount_layout(env, parent, layout_ids, mode, field, field_layout); + } + + let call_name = pick("recursive_tag_increment", "recursive_tag_decrement"); + for ptr in deferred_rec { + // recursively decrement the field + let call = call_help(env, fn_val, mode, ptr, call_name); + call.set_tail_call(true); + } + + // this function returns void + builder.build_return(None); cases.push(( env.context.i64_type().const_int(tag_id as u64, false), @@ -1304,227 +1050,33 @@ pub fn build_dec_rec_union_help<'a, 'ctx, 'env>( env.builder.position_at_end(cont_block); - // read the tag_id - let current_tag_id = rec_union_read_tag(env, value_ptr); + if cases.len() == 1 && !switch_needed { + // there is only one tag in total; we don't need a switch + // this is essential for nullable unwrapped layouts, + // because the `else` branch below would try to read its + // (nonexistant) tag id + let (_, only_branch) = cases.pop().unwrap(); + env.builder.build_unconditional_branch(only_branch); + } else { + // read the tag_id + let current_tag_id = rec_union_read_tag(env, value_ptr); - // switch on it - env.builder - .build_switch(current_tag_id, merge_block, &cases); - - env.builder.position_at_end(merge_block); - - // decrement this cons-cell itself - let refcount_ptr = PointerToRefcount::from_ptr_to_data(env, value_ptr); - refcount_ptr.decrement(env, &layout); - - // this function returns void - builder.build_return(None); -} - -pub fn build_dec_union<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - layout_ids: &mut LayoutIds<'a>, - fields: &'a [&'a [Layout<'a>]], - value: BasicValueEnum<'ctx>, -) { - let layout = Layout::Union(UnionLayout::NonRecursive(fields)); - - let block = env.builder.get_insert_block().expect("to be in a function"); - let di_location = env.builder.get_current_debug_location().unwrap(); - - let symbol = Symbol::DEC; - let fn_name = layout_ids - .get(symbol, &layout) - .to_symbol_string(symbol, &env.interns); - - let function = match env.module.get_function(fn_name.as_str()) { - Some(function_value) => function_value, - None => { - let basic_type = block_of_memory(env.context, &layout, env.ptr_bytes); - let function_value = build_header(env, basic_type, &fn_name); - - build_dec_union_help(env, layout_ids, fields, function_value); - - function_value - } - }; - - env.builder.position_at_end(block); - env.builder - .set_current_debug_location(env.context, di_location); - - let call = env - .builder - .build_call(function, &[value], "decrement_union"); - - call.set_call_convention(FAST_CALL_CONV); -} - -pub fn build_dec_union_help<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - layout_ids: &mut LayoutIds<'a>, - tags: &[&[Layout<'a>]], - fn_val: FunctionValue<'ctx>, -) { - debug_assert!(!tags.is_empty()); - - let context = &env.context; - let builder = env.builder; - - // Add a basic block for the entry point - let entry = context.append_basic_block(fn_val, "entry"); - - builder.position_at_end(entry); - - let func_scope = fn_val.get_subprogram().unwrap(); - let lexical_block = env.dibuilder.create_lexical_block( - /* scope */ func_scope.as_debug_info_scope(), - /* file */ env.compile_unit.get_file(), - /* line_no */ 0, - /* column_no */ 0, - ); - - let loc = env.dibuilder.create_debug_location( - context, - /* line */ 0, - /* column */ 0, - /* current_scope */ lexical_block.as_debug_info_scope(), - /* inlined_at */ None, - ); - builder.set_current_debug_location(&context, loc); - - // Add args to scope - let arg_symbol = Symbol::ARG_1; - - let arg_val = fn_val.get_param_iter().next().unwrap(); - - set_name(arg_val, arg_symbol.ident_string(&env.interns)); - - let parent = fn_val; - - let before_block = env.builder.get_insert_block().expect("to be in a function"); - - debug_assert!(arg_val.is_struct_value()); - let wrapper_struct = arg_val.into_struct_value(); - - // next, make a jump table for all possible values of the tag_id - let mut cases = Vec::with_capacity_in(tags.len(), env.arena); - - let merge_block = env.context.append_basic_block(parent, "decrement_merge"); - - builder.set_current_debug_location(&context, loc); - - for (tag_id, field_layouts) in tags.iter().enumerate() { - // if none of the fields are or contain anything refcounted, just move on - if !field_layouts - .iter() - .any(|x| x.is_refcounted() || x.contains_refcounted()) - { - continue; - } - - let block = env.context.append_basic_block(parent, "tag_id_decrement"); - env.builder.position_at_end(block); - - let wrapper_type = basic_type_from_layout( - env.arena, - env.context, - &Layout::Struct(field_layouts), - env.ptr_bytes, - ); - - debug_assert!(wrapper_type.is_struct_type()); - let wrapper_struct = cast_block_of_memory_to_tag(env.builder, wrapper_struct, wrapper_type); - - for (i, field_layout) in field_layouts.iter().enumerate() { - if let Layout::RecursivePointer = field_layout { - panic!("a non-recursive tag union cannot contain RecursivePointer"); - } else if field_layout.contains_refcounted() { - let field_ptr = env - .builder - .build_extract_value(wrapper_struct, i as u32, "decrement_struct_field") - .unwrap(); - - decrement_refcount_layout(env, parent, layout_ids, field_ptr, field_layout); - } - } - - env.builder.build_unconditional_branch(merge_block); - - cases.push(( - env.context.i64_type().const_int(tag_id as u64, false), - block, - )); - } - - cases.reverse(); - - env.builder.position_at_end(before_block); - - // read the tag_id - let current_tag_id = { - // the first element of the wrapping struct is an array of i64 - let first_array = env - .builder - .build_extract_value(wrapper_struct, 0, "read_tag_id") - .unwrap() - .into_array_value(); + let merge_block = env + .context + .append_basic_block(parent, pick("increment_merge", "decrement_merge")); + // switch on it env.builder - .build_extract_value(first_array, 0, "read_tag_id_2") - .unwrap() - .into_int_value() - }; + .build_switch(current_tag_id, merge_block, &cases); - // switch on it - env.builder - .build_switch(current_tag_id, merge_block, &cases); + env.builder.position_at_end(merge_block); - env.builder.position_at_end(merge_block); + // increment/decrement the cons-cell itself + refcount_ptr.modify(call_mode, &layout, env); - // this function returns void - builder.build_return(None); -} - -pub fn build_inc_rec_union<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - layout_ids: &mut LayoutIds<'a>, - fields: &'a [&'a [Layout<'a>]], - value: PointerValue<'ctx>, - is_nullable: bool, -) { - let layout = Layout::Union(UnionLayout::Recursive(fields)); - - let block = env.builder.get_insert_block().expect("to be in a function"); - let di_location = env.builder.get_current_debug_location().unwrap(); - - let symbol = Symbol::INC; - let fn_name = layout_ids - .get(symbol, &layout) - .to_symbol_string(symbol, &env.interns); - - let function = match env.module.get_function(fn_name.as_str()) { - Some(function_value) => function_value, - None => { - let basic_type = block_of_memory_slices(env.context, fields, env.ptr_bytes) - .ptr_type(AddressSpace::Generic) - .into(); - let function_value = build_header(env, basic_type, &fn_name); - - build_inc_rec_union_help(env, layout_ids, fields, function_value, is_nullable); - - function_value - } - }; - - env.builder.position_at_end(block); - env.builder - .set_current_debug_location(env.context, di_location); - let call = env - .builder - .build_call(function, &[value.into()], "increment_union"); - - call.set_call_convention(FAST_CALL_CONV); + // this function returns void + builder.build_return(None); + } } fn rec_union_read_tag<'a, 'ctx, 'env>( @@ -1544,178 +1096,32 @@ fn rec_union_read_tag<'a, 'ctx, 'env>( .into_int_value() } -pub fn build_inc_rec_union_help<'a, 'ctx, 'env>( +fn call_help<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, - layout_ids: &mut LayoutIds<'a>, - tags: &[&[Layout<'a>]], - fn_val: FunctionValue<'ctx>, - is_nullable: bool, -) { - debug_assert!(!tags.is_empty()); + function: FunctionValue<'ctx>, + mode: Mode, + value: BasicValueEnum<'ctx>, + call_name: &str, +) -> inkwell::values::CallSiteValue<'ctx> { + let call = match mode { + Mode::Inc(inc_amount) => { + let rc_increment = ptr_int(env.context, env.ptr_bytes).const_int(inc_amount, false); - let context = &env.context; - let builder = env.builder; - - // Add a basic block for the entry point - let entry = context.append_basic_block(fn_val, "entry"); - - builder.position_at_end(entry); - - let func_scope = fn_val.get_subprogram().unwrap(); - let lexical_block = env.dibuilder.create_lexical_block( - /* scope */ func_scope.as_debug_info_scope(), - /* file */ env.compile_unit.get_file(), - /* line_no */ 0, - /* column_no */ 0, - ); - - let loc = env.dibuilder.create_debug_location( - context, - /* line */ 0, - /* column */ 0, - /* current_scope */ lexical_block.as_debug_info_scope(), - /* inlined_at */ None, - ); - builder.set_current_debug_location(&context, loc); - - // Add args to scope - let arg_symbol = Symbol::ARG_1; - let arg_val = fn_val.get_param_iter().next().unwrap(); - - set_name(arg_val, arg_symbol.ident_string(&env.interns)); - - let parent = fn_val; - - debug_assert!(arg_val.is_pointer_value()); - let value_ptr = arg_val.into_pointer_value(); - - let ctx = env.context; - let cont_block = ctx.append_basic_block(parent, "cont"); - if is_nullable { - let is_null = env.builder.build_is_null(value_ptr, "is_null"); - - let then_block = ctx.append_basic_block(parent, "then"); - - env.builder.build_switch( - is_null, - cont_block, - &[(ctx.bool_type().const_int(1, false), then_block)], - ); - - { - env.builder.position_at_end(then_block); - env.builder.build_return(None); + env.builder + .build_call(function, &[value, rc_increment.into()], call_name) } - } else { - env.builder.build_unconditional_branch(cont_block); - } + Mode::Dec => env.builder.build_call(function, &[value], call_name), + }; - // next, make a jump table for all possible values of the tag_id - let mut cases = Vec::with_capacity_in(tags.len(), env.arena); + call.set_call_convention(FAST_CALL_CONV); - let merge_block = env.context.append_basic_block(parent, "increment_merge"); - - for (tag_id, field_layouts) in tags.iter().enumerate() { - // if none of the fields are or contain anything refcounted, just move on - if !field_layouts - .iter() - .any(|x| x.is_refcounted() || x.contains_refcounted()) - { - continue; - } - - let block = env.context.append_basic_block(parent, "tag_id_increment"); - env.builder.position_at_end(block); - - let wrapper_type = basic_type_from_layout( - env.arena, - env.context, - &Layout::Struct(field_layouts), - env.ptr_bytes, - ); - - // cast the opaque pointer to a pointer of the correct shape - let struct_ptr = env - .builder - .build_bitcast( - value_ptr, - wrapper_type.ptr_type(AddressSpace::Generic), - "opaque_to_correct", - ) - .into_pointer_value(); - - for (i, field_layout) in field_layouts.iter().enumerate() { - if let Layout::RecursivePointer = field_layout { - // this field has type `*i64`, but is really a pointer to the data we want - let elem_pointer = env - .builder - .build_struct_gep(struct_ptr, i as u32, "gep_recursive_pointer") - .unwrap(); - - let ptr_as_i64_ptr = env - .builder - .build_load(elem_pointer, "load_recursive_pointer"); - - debug_assert!(ptr_as_i64_ptr.is_pointer_value()); - - // therefore we must cast it to our desired type - let union_type = block_of_memory_slices(env.context, tags, env.ptr_bytes); - let recursive_field_ptr = env.builder.build_bitcast( - ptr_as_i64_ptr, - union_type.ptr_type(AddressSpace::Generic), - "recursive_to_desired", - ); - - // recursively increment the field - let call = env.builder.build_call( - fn_val, - &[recursive_field_ptr], - "recursive_tag_increment", - ); - - // Because it's an internal-only function, use the fast calling convention. - call.set_call_convention(FAST_CALL_CONV); - } else if field_layout.contains_refcounted() { - let elem_pointer = env - .builder - .build_struct_gep(struct_ptr, i as u32, "gep_field") - .unwrap(); - - let field = env.builder.build_load(elem_pointer, "load_field"); - - increment_refcount_layout(env, parent, layout_ids, field, field_layout); - } - } - - env.builder.build_unconditional_branch(merge_block); - - cases.push((env.context.i8_type().const_int(tag_id as u64, false), block)); - } - - env.builder.position_at_end(cont_block); - - // read the tag_id - let tag_id = rec_union_read_tag(env, value_ptr); - - let tag_id_u8 = env - .builder - .build_int_cast(tag_id, env.context.i8_type(), "tag_id_u8"); - - env.builder.build_switch(tag_id_u8, merge_block, &cases); - - env.builder.position_at_end(merge_block); - - // increment this cons cell - let refcount_ptr = PointerToRefcount::from_ptr_to_data(env, value_ptr); - refcount_ptr.increment(env); - - // this function returns void - builder.build_return(None); + call } -pub fn build_inc_union<'a, 'ctx, 'env>( +fn modify_refcount_union<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, + mode: Mode, fields: &'a [&'a [Layout<'a>]], value: BasicValueEnum<'ctx>, ) { @@ -1724,7 +1130,11 @@ pub fn build_inc_union<'a, 'ctx, 'env>( let block = env.builder.get_insert_block().expect("to be in a function"); let di_location = env.builder.get_current_debug_location().unwrap(); - let symbol = Symbol::INC; + let (call_name, symbol) = match mode { + Mode::Inc(_) => ("increment_union", Symbol::INC), + Mode::Dec => ("decrement_union", Symbol::DEC), + }; + let fn_name = layout_ids .get(symbol, &layout) .to_symbol_string(symbol, &env.interns); @@ -1733,9 +1143,9 @@ pub fn build_inc_union<'a, 'ctx, 'env>( Some(function_value) => function_value, None => { let basic_type = block_of_memory(env.context, &layout, env.ptr_bytes); - let function_value = build_header(env, basic_type, &fn_name); + let function_value = build_header(env, basic_type, mode, &fn_name); - build_inc_union_help(env, layout_ids, fields, function_value); + modify_refcount_union_help(env, layout_ids, mode, fields, function_value); function_value } @@ -1744,16 +1154,14 @@ pub fn build_inc_union<'a, 'ctx, 'env>( env.builder.position_at_end(block); env.builder .set_current_debug_location(env.context, di_location); - let call = env - .builder - .build_call(function, &[value], "increment_union"); - call.set_call_convention(FAST_CALL_CONV); + call_help(env, function, mode, value, call_name); } -pub fn build_inc_union_help<'a, 'ctx, 'env>( +fn modify_refcount_union_help<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, + mode: Mode, tags: &[&[Layout<'a>]], fn_val: FunctionValue<'ctx>, ) { @@ -1792,7 +1200,6 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>( let parent = fn_val; - let layout = Layout::Union(UnionLayout::Recursive(tags)); let before_block = env.builder.get_insert_block().expect("to be in a function"); let wrapper_struct = arg_val.into_struct_value(); @@ -1819,7 +1226,9 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>( // next, make a jump table for all possible values of the tag_id let mut cases = Vec::with_capacity_in(tags.len(), env.arena); - let merge_block = env.context.append_basic_block(parent, "increment_merge"); + let merge_block = env + .context + .append_basic_block(parent, "modify_rc_union_merge"); for (tag_id, field_layouts) in tags.iter().enumerate() { // if none of the fields are or contain anything refcounted, just move on @@ -1830,7 +1239,7 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>( continue; } - let block = env.context.append_basic_block(parent, "tag_id_increment"); + let block = env.context.append_basic_block(parent, "tag_id_modify"); env.builder.position_at_end(block); let wrapper_type = basic_type_from_layout( @@ -1845,48 +1254,14 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>( for (i, field_layout) in field_layouts.iter().enumerate() { if let Layout::RecursivePointer = field_layout { - // this field has type `*i64`, but is really a pointer to the data we want - let ptr_as_i64_ptr = env - .builder - .build_extract_value(wrapper_struct, i as u32, "increment_struct_field") - .unwrap(); - - debug_assert!(ptr_as_i64_ptr.is_pointer_value()); - - // therefore we must cast it to our desired type - let union_type = block_of_memory(env.context, &layout, env.ptr_bytes); - let recursive_field_ptr = env - .builder - .build_bitcast( - ptr_as_i64_ptr, - union_type.ptr_type(AddressSpace::Generic), - "recursive_to_desired", - ) - .into_pointer_value(); - - let recursive_field = env - .builder - .build_load(recursive_field_ptr, "load_recursive_field"); - - // recursively increment the field - let call = - env.builder - .build_call(fn_val, &[recursive_field], "recursive_tag_increment"); - - // Because it's an internal-only function, use the fast calling convention. - call.set_call_convention(FAST_CALL_CONV); - - // TODO do this decrement before the recursive call? - // Then the recursive call is potentially TCE'd - let refcount_ptr = PointerToRefcount::from_ptr_to_data(env, recursive_field_ptr); - refcount_ptr.increment(env); + panic!("non-recursive tag unions cannot contain naked recursion pointers!"); } else if field_layout.contains_refcounted() { let field_ptr = env .builder - .build_extract_value(wrapper_struct, i as u32, "increment_struct_field") + .build_extract_value(wrapper_struct, i as u32, "modify_tag_field") .unwrap(); - increment_refcount_layout(env, parent, layout_ids, field_ptr, field_layout); + modify_refcount_layout(env, parent, layout_ids, mode, field_ptr, field_layout); } } diff --git a/compiler/gen/tests/gen_list.rs b/compiler/gen/tests/gen_list.rs index cb01dc7760..403cdbee33 100644 --- a/compiler/gen/tests/gen_list.rs +++ b/compiler/gen/tests/gen_list.rs @@ -27,6 +27,11 @@ mod gen_list { assert_evals_to!("[]", RocList::from_slice(&[]), RocList); } + #[test] + fn list_literal_empty_record() { + assert_evals_to!("[{}]", RocList::from_slice(&[()]), RocList<()>); + } + #[test] fn int_singleton_list_literal() { assert_evals_to!("[1, 2]", RocList::from_slice(&[1, 2]), RocList); diff --git a/compiler/gen/tests/gen_primitives.rs b/compiler/gen/tests/gen_primitives.rs index 89409319b5..25eb406032 100644 --- a/compiler/gen/tests/gen_primitives.rs +++ b/compiler/gen/tests/gen_primitives.rs @@ -13,6 +13,8 @@ mod helpers; #[cfg(test)] mod gen_primitives { + use roc_std::RocStr; + #[test] fn basic_int() { assert_evals_to!("123", 123, i64); @@ -1268,7 +1270,6 @@ mod gen_primitives { } #[test] - #[ignore] fn rbtree_insert() { assert_non_opt_evals_to!( indoc!( @@ -1338,54 +1339,20 @@ mod gen_primitives { _ -> Node color key value left right - main : RedBlackTree (Int *) {} + show : RedBlackTree I64 {} -> Str + show = \tree -> + when tree is + Empty -> "Empty" + Node _ _ _ _ _ -> "Node" + + + main : Str main = - insert 0 {} Empty + show (insert 0 {} Empty) "# ), - 1, - i64 - ); - } - - #[test] - #[ignore] - fn rbtree_balance_inc_dec() { - // TODO does not define a variable correctly, but all is well with the type signature - assert_non_opt_evals_to!( - indoc!( - r#" - app "test" provides [ main ] to "./platform" - - NodeColor : [ Red, Black ] - - RedBlackTree k : [ Node NodeColor k (RedBlackTree k) (RedBlackTree k), Empty ] - - # balance : NodeColor, k, RedBlackTree k, RedBlackTree k -> RedBlackTree k - balance = \color, key, left, right -> - when right is - Node Red rK rLeft rRight -> - when left is - Node Red _ _ _ -> - Node - Red - key - Empty - Empty - - _ -> - Node color rK (Node Red key left rLeft) rRight - - _ -> - Empty - - main : RedBlackTree (Int *) - main = - balance Red 0 Empty Empty - "# - ), - 0, - i64 + RocStr::from_slice("Node".as_bytes()), + RocStr ); } @@ -1413,6 +1380,47 @@ mod gen_primitives { ); } + #[test] + #[ignore] + fn rbtree_layout_issue() { + // there is a flex var in here somewhere that blows up layout creation + assert_non_opt_evals_to!( + indoc!( + r#" + app "test" provides [ main ] to "./platform" + + NodeColor : [ Red, Black ] + + RedBlackTree k v : [ Node NodeColor k v (RedBlackTree k v) (RedBlackTree k v), Empty ] + + # balance : NodeColor, k, v, RedBlackTree k v -> RedBlackTree k v + balance = \color, key, value, right -> + when right is + Node Red _ _ rLeft rRight -> + Node color key value rLeft rRight + + + _ -> + Empty + + show : RedBlackTree * * -> Str + show = \tree -> + when tree is + Empty -> "Empty" + Node _ _ _ _ _ -> "Node" + + zero : I64 + zero = 0 + + main : Str + main = show (balance Red zero zero Empty) + "# + ), + RocStr::from_slice("Empty".as_bytes()), + RocStr + ); + } + #[test] #[ignore] fn rbtree_balance_mono_problem() { @@ -1450,13 +1458,19 @@ mod gen_primitives { _ -> Empty - main : RedBlackTree (Int *) (Int *) - main = - balance Red 0 0 Empty Empty + show : RedBlackTree * * -> Str + show = \tree -> + when tree is + Empty -> "Empty" + Node _ _ _ _ _ -> "Node" + + + main : Str + main = show (balance Red 0 0 Empty Empty) "# ), - 1, - i64 + RocStr::from_slice("Empty".as_bytes()), + RocStr ); } @@ -1922,29 +1936,23 @@ mod gen_primitives { } #[test] - #[ignore] fn rosetree_basic() { assert_non_opt_evals_to!( indoc!( r#" app "test" provides [ main ] to "./platform" - # RoseTree Tree a : [ Tree a (List (Tree a)) ] - tree : a, List (Tree a) -> Tree a - tree = \a, t -> Tree a t - singleton : a -> Tree a singleton = \x -> Tree x [] main : Bool main = - x : I64 - x = 1 - - when tree x [ singleton 5, singleton 3 ] is - Tree 0x1 _ -> True + x : Tree F64 + x = singleton 3 + when x is + Tree 3.0 _ -> True _ -> False "# ), @@ -2117,4 +2125,36 @@ mod gen_primitives { i64 ); } + + #[test] + fn multiple_increment() { + // the `leaf` value will be incremented multiple times at once + assert_evals_to!( + indoc!( + r#" + app "test" provides [ main ] to "./platform" + + Color : [ Red, Black ] + + Tree a b : [ Leaf, Node Color (Tree a b) a b (Tree a b) ] + + Map : Tree I64 Bool + + main : I64 + main = + leaf : Map + leaf = Leaf + + m : Map + m = Node Black (Node Black leaf 10 False leaf) 11 False (Node Black leaf 12 False (Node Red leaf 13 False leaf)) + + when m is + Leaf -> 0 + Node _ _ _ _ _ -> 1 + "# + ), + 1, + i64 + ); + } } diff --git a/compiler/gen/tests/gen_tags.rs b/compiler/gen/tests/gen_tags.rs index 6f7d0674d7..0f547ed3e2 100644 --- a/compiler/gen/tests/gen_tags.rs +++ b/compiler/gen/tests/gen_tags.rs @@ -960,4 +960,27 @@ mod gen_tags { |x: &i64| *x ); } + + #[test] + fn newtype_wrapper() { + assert_evals_to!( + indoc!( + r#" + app "test" provides [ main ] to "./platform" + + ConsList a : [ Nil, Cons a (ConsList a) ] + + foo : ConsList I64 -> ConsList I64 + foo = \t -> + when Delmin (Del t 0.0) is + Delmin (Del ry _) -> Cons 42 ry + + main = foo Nil + "# + ), + 42, + &i64, + |x: &i64| *x + ); + } } diff --git a/compiler/gen/tests/helpers/eval.rs b/compiler/gen/tests/helpers/eval.rs index 76801146ad..b242b68ec5 100644 --- a/compiler/gen/tests/helpers/eval.rs +++ b/compiler/gen/tests/helpers/eval.rs @@ -68,15 +68,6 @@ pub fn helper<'a>( debug_assert_eq!(exposed_to_host.len(), 1); let main_fn_symbol = exposed_to_host.keys().copied().next().unwrap(); - let (_, main_fn_layout) = match procedures.keys().find(|(s, _)| *s == main_fn_symbol) { - Some(found) => found.clone(), - None => panic!( - "The main function symbol {:?} does not have a procedure in {:?}", - main_fn_symbol, - &procedures.keys() - ), - }; - let mut lines = Vec::new(); // errors whose reporting we delay (so we can see that code gen generates runtime errors) let mut delayed_errors = Vec::new(); @@ -160,6 +151,15 @@ pub fn helper<'a>( } } + let (_, main_fn_layout) = match procedures.keys().find(|(s, _)| *s == main_fn_symbol) { + Some(found) => found.clone(), + None => panic!( + "The main function symbol {:?} does not have a procedure in {:?}", + main_fn_symbol, + &procedures.keys() + ), + }; + let module = roc_gen::llvm::build::module_from_builtins(context, "app"); // strip Zig debug stuff diff --git a/compiler/gen_dev/src/lib.rs b/compiler/gen_dev/src/lib.rs index 8e6bcfa094..0d1a5bf2ed 100644 --- a/compiler/gen_dev/src/lib.rs +++ b/compiler/gen_dev/src/lib.rs @@ -404,7 +404,7 @@ where self.set_last_seen(*sym, stmt); } Stmt::Rethrow => {} - Stmt::Inc(sym, following) => { + Stmt::Inc(sym, _inc, following) => { self.set_last_seen(*sym, stmt); self.scan_ast(following); } diff --git a/compiler/load/src/file.rs b/compiler/load/src/file.rs index a114ade91b..9fbef73072 100644 --- a/compiler/load/src/file.rs +++ b/compiler/load/src/file.rs @@ -1803,6 +1803,8 @@ fn update<'a>( if state.dependencies.solved_all() && state.goal_phase == Phase::MakeSpecializations { debug_assert!(work.is_empty(), "still work remaining {:?}", &work); + Proc::insert_refcount_operations(arena, &mut state.procedures); + // display the mono IR of the module, for debug purposes if roc_mono::ir::PRETTY_PRINT_IR_SYMBOLS { let procs_string = state @@ -1816,8 +1818,6 @@ fn update<'a>( println!("{}", result); } - Proc::insert_refcount_operations(arena, &mut state.procedures); - msg_tx .send(Msg::FinishedAllSpecialization { subs, diff --git a/compiler/mono/src/borrow.rs b/compiler/mono/src/borrow.rs index 7d5312aca1..99f44d971e 100644 --- a/compiler/mono/src/borrow.rs +++ b/compiler/mono/src/borrow.rs @@ -168,7 +168,7 @@ impl<'a> ParamMap<'a> { stack.extend(branches.iter().map(|b| &b.1)); stack.push(default_branch); } - Inc(_, _) | Dec(_, _) => unreachable!("these have not been introduced yet"), + Inc(_, _, _) | Dec(_, _) => unreachable!("these have not been introduced yet"), Ret(_) | Rethrow | Jump(_, _) | RuntimeError(_) => { // these are terminal, do nothing @@ -513,7 +513,7 @@ impl<'a> BorrowInfState<'a> { } self.collect_stmt(default_branch); } - Inc(_, _) | Dec(_, _) => unreachable!("these have not been introduced yet"), + Inc(_, _, _) | Dec(_, _) => unreachable!("these have not been introduced yet"), Ret(_) | RuntimeError(_) | Rethrow => { // these are terminal, do nothing diff --git a/compiler/mono/src/decision_tree.rs b/compiler/mono/src/decision_tree.rs index ac1e9bd827..f7e1f56b8c 100644 --- a/compiler/mono/src/decision_tree.rs +++ b/compiler/mono/src/decision_tree.rs @@ -602,7 +602,21 @@ fn to_relevant_branch_help<'a>( } } Wrapped::RecordOrSingleTagUnion => { - todo!("this should need a special index, right?") + let sub_positions = arguments.into_iter().enumerate().map( + |(index, (pattern, _))| { + ( + Path::Index { + index: index as u64, + tag_id, + path: Box::new(path.clone()), + }, + Guard::NoGuard, + pattern, + ) + }, + ); + start.extend(sub_positions); + start.extend(end); } Wrapped::MultiTagUnion => { let sub_positions = arguments.into_iter().enumerate().map( @@ -1013,6 +1027,8 @@ fn path_to_expr_help<'a>( debug_assert!(*index < field_layouts.len() as u64); + debug_assert_eq!(field_layouts.len(), 1); + let inner_layout = field_layouts[*index as usize].clone(); let inner_expr = Expr::AccessAtIndex { index: *index, @@ -1035,6 +1051,10 @@ fn path_to_expr_help<'a>( match variant { NonRecursive(layouts) | Recursive(layouts) => layouts[*tag_id as usize], + NonNullableUnwrapped(fields) => { + debug_assert_eq!(*tag_id, 0); + fields + } NullableWrapped { nullable_id, other_tags: layouts, diff --git a/compiler/mono/src/inc_dec.rs b/compiler/mono/src/inc_dec.rs index 27503a35e1..6b67823c86 100644 --- a/compiler/mono/src/inc_dec.rs +++ b/compiler/mono/src/inc_dec.rs @@ -52,7 +52,7 @@ pub fn occuring_variables(stmt: &Stmt<'_>) -> (MutSet, MutSet) { Rethrow => {} - Inc(symbol, cont) | Dec(symbol, cont) => { + Inc(symbol, _, cont) | Dec(symbol, cont) => { result.insert(*symbol); stack.push(cont); } @@ -262,7 +262,9 @@ impl<'a> Context<'a> { } } - fn add_inc(&self, symbol: Symbol, stmt: &'a Stmt<'a>) -> &'a Stmt<'a> { + fn add_inc(&self, symbol: Symbol, inc_amount: u64, stmt: &'a Stmt<'a>) -> &'a Stmt<'a> { + debug_assert!(inc_amount > 0); + let info = self.get_var_info(symbol); if info.persistent { @@ -275,7 +277,7 @@ impl<'a> Context<'a> { return stmt; } - self.arena.alloc(Stmt::Inc(symbol, stmt)) + self.arena.alloc(Stmt::Inc(symbol, inc_amount, stmt)) } fn add_dec(&self, symbol: Symbol, stmt: &'a Stmt<'a>) -> &'a Stmt<'a> { @@ -335,11 +337,8 @@ impl<'a> Context<'a> { num_consumptions - 1 }; - // verify that this is indeed always 1 - debug_assert!(num_incs <= 1); - - if num_incs == 1 { - b = self.add_inc(*x, b) + if num_incs >= 1 { + b = self.add_inc(*x, num_incs as u64, b) } } } @@ -524,7 +523,7 @@ impl<'a> Context<'a> { let b = self.add_dec_if_needed(x, b, b_live_vars); let info_x = self.get_var_info(x); let b = if info_x.consume { - self.add_inc(z, b) + self.add_inc(z, 1, b) } else { b }; @@ -786,7 +785,7 @@ impl<'a> Context<'a> { live_vars.insert(*x); if info.reference && !info.consume { - (self.add_inc(*x, stmt), live_vars) + (self.add_inc(*x, 1, stmt), live_vars) } else { (stmt, live_vars) } @@ -851,7 +850,7 @@ impl<'a> Context<'a> { (switch, case_live_vars) } - RuntimeError(_) | Inc(_, _) | Dec(_, _) => (stmt, MutSet::default()), + RuntimeError(_) | Inc(_, _, _) | Dec(_, _) => (stmt, MutSet::default()), } } } @@ -902,7 +901,7 @@ pub fn collect_stmt( vars } - Inc(symbol, cont) | Dec(symbol, cont) => { + Inc(symbol, _, cont) | Dec(symbol, cont) => { vars.insert(*symbol); collect_stmt(cont, jp_live_vars, vars) } diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index 2ff329e301..cd0e12d8a9 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -766,7 +766,7 @@ pub enum Stmt<'a> { }, Ret(Symbol), Rethrow, - Inc(Symbol, &'a Stmt<'a>), + Inc(Symbol, u64, &'a Stmt<'a>), Dec(Symbol, &'a Stmt<'a>), Join { id: JoinPointId, @@ -834,6 +834,8 @@ impl Wrapped { }, _ => Some(Wrapped::MultiTagUnion), }, + NonNullableUnwrapped(_) => Some(Wrapped::RecordOrSingleTagUnion), + NullableWrapped { .. } | NullableUnwrapped { .. } => { Some(Wrapped::MultiTagUnion) } @@ -1127,9 +1129,11 @@ impl<'a> Stmt<'a> { use Stmt::*; match self { - Let(symbol, expr, _, cont) => alloc + Let(symbol, expr, _layout, cont) => alloc .text("let ") .append(symbol_to_doc(alloc, *symbol)) + //.append(" : ") + //.append(alloc.text(format!("{:?}", layout))) .append(" = ") .append(expr.to_doc(alloc)) .append(";") @@ -1263,12 +1267,19 @@ impl<'a> Stmt<'a> { .append(alloc.intersperse(it, alloc.space())) .append(";") } - Inc(symbol, cont) => alloc + Inc(symbol, 1, cont) => alloc .text("inc ") .append(symbol_to_doc(alloc, *symbol)) .append(";") .append(alloc.hardline()) .append(cont.to_doc(alloc)), + Inc(symbol, n, cont) => alloc + .text("inc ") + .append(alloc.text(format!("{}", n))) + .append(symbol_to_doc(alloc, *symbol)) + .append(";") + .append(alloc.hardline()) + .append(cont.to_doc(alloc)), Dec(symbol, cont) => alloc .text("dec ") .append(symbol_to_doc(alloc, *symbol)) @@ -2731,6 +2742,7 @@ pub fn with_hole<'a>( use WrappedVariant::*; let (tag, layout) = match variant { Recursive { sorted_tag_layouts } => { + debug_assert!(sorted_tag_layouts.len() > 1); let tag_id_symbol = env.unique_symbol(); opt_tag_id_symbol = Some(tag_id_symbol); @@ -2751,6 +2763,7 @@ pub fn with_hole<'a>( layouts.push(arg_layouts); } + debug_assert!(layouts.len() > 1); let layout = Layout::Union(UnionLayout::Recursive(layouts.into_bump_slice())); @@ -2764,6 +2777,35 @@ pub fn with_hole<'a>( (tag, layout) } + NonNullableUnwrapped { + fields, + tag_name: wrapped_tag_name, + } => { + debug_assert_eq!(tag_name, wrapped_tag_name); + + opt_tag_id_symbol = None; + + field_symbols = { + let mut temp = + Vec::with_capacity_in(field_symbols_temp.len(), arena); + + temp.extend(field_symbols_temp.iter().map(|r| r.1)); + + temp.into_bump_slice() + }; + + let layout = Layout::Union(UnionLayout::NonNullableUnwrapped(fields)); + + let tag = Expr::Tag { + tag_layout: layout.clone(), + tag_name, + tag_id: tag_id as u8, + union_size, + arguments: field_symbols, + }; + + (tag, layout) + } NonRecursive { sorted_tag_layouts } => { let tag_id_symbol = env.unique_symbol(); opt_tag_id_symbol = Some(tag_id_symbol); @@ -3067,7 +3109,8 @@ pub fn with_hole<'a>( ); for (loc_cond, loc_then) in branches.into_iter().rev() { - let branching_symbol = env.unique_symbol(); + let branching_symbol = possible_reuse_symbol(env, procs, &loc_cond.value); + let then = with_hole( env, loc_then.value, @@ -3088,14 +3131,14 @@ pub fn with_hole<'a>( ); // add condition - stmt = with_hole( + stmt = assign_to_symbol( env, - loc_cond.value, - cond_var, procs, layout_cache, + cond_var, + loc_cond, branching_symbol, - env.arena.alloc(stmt), + stmt, ); } @@ -3254,7 +3297,10 @@ pub fn with_hole<'a>( match Wrapped::opt_from_layout(&record_layout) { Some(result) => result, - None => Wrapped::SingleElementRecord, + None => { + debug_assert_eq!(field_layouts.len(), 1); + Wrapped::SingleElementRecord + } } }; @@ -4662,8 +4708,8 @@ fn substitute_in_stmt_help<'a>( Some(s) => Some(arena.alloc(Ret(s))), None => None, }, - Inc(symbol, cont) => match substitute_in_stmt_help(arena, cont, subs) { - Some(cont) => Some(arena.alloc(Inc(*symbol, cont))), + Inc(symbol, inc, cont) => match substitute_in_stmt_help(arena, cont, subs) { + Some(cont) => Some(arena.alloc(Inc(*symbol, *inc, cont))), None => None, }, Dec(symbol, cont) => match substitute_in_stmt_help(arena, cont, subs) { @@ -5584,7 +5630,11 @@ fn call_by_name<'a>( partial_proc, ) { Ok((proc, layout)) => { - debug_assert_eq!(full_layout, layout); + debug_assert_eq!( + &full_layout, &layout, + "\n\n{:?}\n\n{:?}", + full_layout, layout + ); let function_layout = FunctionLayouts::from_layout(env.arena, layout); @@ -6035,7 +6085,12 @@ fn from_can_pattern_help<'a>( let mut mono_args = Vec::with_capacity_in(arguments.len(), env.arena); - debug_assert_eq!(arguments.len(), argument_layouts[1..].len()); + debug_assert_eq!( + arguments.len(), + argument_layouts[1..].len(), + "{:?}", + tag_name + ); let it = argument_layouts[1..].iter(); for ((_, loc_pat), layout) in arguments.iter().zip(it) { @@ -6118,6 +6173,7 @@ fn from_can_pattern_help<'a>( temp }; + debug_assert!(layouts.len() > 1); let layout = Layout::Union(UnionLayout::Recursive(layouts.into_bump_slice())); @@ -6130,6 +6186,51 @@ fn from_can_pattern_help<'a>( } } + NonNullableUnwrapped { + tag_name: w_tag_name, + fields, + } => { + debug_assert_eq!(&w_tag_name, tag_name); + + ctors.push(Ctor { + tag_id: TagId(0_u8), + name: tag_name.clone(), + arity: fields.len(), + }); + + let union = crate::exhaustive::Union { + render_as: RenderAs::Tag, + alternatives: ctors, + }; + + let mut mono_args = Vec::with_capacity_in(arguments.len(), env.arena); + + debug_assert_eq!(arguments.len(), argument_layouts.len()); + let it = argument_layouts.iter(); + + for ((_, loc_pat), layout) in arguments.iter().zip(it) { + mono_args.push(( + from_can_pattern_help( + env, + layout_cache, + &loc_pat.value, + assignments, + )?, + layout.clone(), + )); + } + + let layout = Layout::Union(UnionLayout::NonNullableUnwrapped(fields)); + + Pattern::AppliedTag { + tag_name: tag_name.clone(), + tag_id: tag_id as u8, + arguments: mono_args, + union, + layout, + } + } + NullableWrapped { sorted_tag_layouts: ref tags, nullable_id, diff --git a/compiler/mono/src/layout.rs b/compiler/mono/src/layout.rs index 3fc476911f..25ff4f61b9 100644 --- a/compiler/mono/src/layout.rs +++ b/compiler/mono/src/layout.rs @@ -43,8 +43,11 @@ pub enum UnionLayout<'a> { /// e.g. `Result a e : [ Ok a, Err e ]` NonRecursive(&'a [&'a [Layout<'a>]]), /// A recursive tag union - /// e.g. `RoseTree a : [ Tree a (List (RoseTree a)) ]` + /// e.g. `Expr : [ Sym Str, Add Expr Expr ]` Recursive(&'a [&'a [Layout<'a>]]), + /// A recursive tag union with just one constructor + /// e.g. `RoseTree a : [ Tree a (List (RoseTree a)) ]` + NonNullableUnwrapped(&'a [Layout<'a>]), /// A recursive tag union where the non-nullable variant(s) store the tag id /// e.g. `FingerTree a : [ Empty, Single a, More (Some a) (FingerTree (Tuple a)) (Some a) ]` /// see also: https://youtu.be/ip92VMpf_-A?t=164 @@ -485,7 +488,10 @@ impl<'a> Layout<'a> { NonRecursive(tags) => tags .iter() .all(|tag_layout| tag_layout.iter().all(|field| field.safe_to_memcpy())), - Recursive(_) | NullableWrapped { .. } | NullableUnwrapped { .. } => { + Recursive(_) + | NullableWrapped { .. } + | NullableUnwrapped { .. } + | NonNullableUnwrapped(_) => { // a recursive union will always contain a pointer, and is thus not safe to memcpy false } @@ -549,9 +555,10 @@ impl<'a> Layout<'a> { .max() .unwrap_or_default(), - Recursive(_) | NullableWrapped { .. } | NullableUnwrapped { .. } => { - pointer_size - } + Recursive(_) + | NullableWrapped { .. } + | NullableUnwrapped { .. } + | NonNullableUnwrapped(_) => pointer_size, } } Closure(_, closure_layout, _) => pointer_size + closure_layout.stack_size(pointer_size), @@ -580,9 +587,10 @@ impl<'a> Layout<'a> { .map(|x| x.alignment_bytes(pointer_size)) .max() .unwrap_or(0), - Recursive(_) | NullableWrapped { .. } | NullableUnwrapped { .. } => { - pointer_size - } + Recursive(_) + | NullableWrapped { .. } + | NullableUnwrapped { .. } + | NonNullableUnwrapped(_) => pointer_size, } } Layout::Builtin(builtin) => builtin.alignment_bytes(pointer_size), @@ -634,7 +642,10 @@ impl<'a> Layout<'a> { .map(|ls| ls.iter()) .flatten() .any(|f| f.contains_refcounted()), - Recursive(_) | NullableWrapped { .. } | NullableUnwrapped { .. } => true, + Recursive(_) + | NullableWrapped { .. } + | NullableUnwrapped { .. } + | NonNullableUnwrapped(_) => true, } } Closure(_, closure_layout, _) => closure_layout.contains_refcounted(), @@ -1116,6 +1127,9 @@ fn layout_from_flat_type<'a>( other_tags: many, }, } + } else if tag_layouts.len() == 1 { + // drop the tag id + UnionLayout::NonNullableUnwrapped(&tag_layouts.pop().unwrap()[1..]) } else { UnionLayout::Recursive(tag_layouts.into_bump_slice()) }; @@ -1220,6 +1234,10 @@ pub enum WrappedVariant<'a> { nullable_name: TagName, sorted_tag_layouts: Vec<'a, (TagName, &'a [Layout<'a>])>, }, + NonNullableUnwrapped { + tag_name: TagName, + fields: &'a [Layout<'a>], + }, NullableUnwrapped { nullable_id: bool, nullable_name: TagName, @@ -1281,6 +1299,7 @@ impl<'a> WrappedVariant<'a> { (!*nullable_id as u8, *other_fields) } } + NonNullableUnwrapped { fields, .. } => (0, fields), } } @@ -1299,6 +1318,7 @@ impl<'a> WrappedVariant<'a> { sorted_tag_layouts.len() + 1 } NullableUnwrapped { .. } => 2, + NonNullableUnwrapped { .. } => 1, } } } @@ -1409,6 +1429,11 @@ pub fn union_sorted_tags_help<'a>( } else { UnionVariant::Unit } + } else if opt_rec_var.is_some() { + UnionVariant::Wrapped(WrappedVariant::NonNullableUnwrapped { + tag_name, + fields: layouts.into_bump_slice(), + }) } else { UnionVariant::Unwrapped(layouts) } @@ -1517,6 +1542,7 @@ pub fn union_sorted_tags_help<'a>( } } } else if is_recursive { + debug_assert!(answer.len() > 1); WrappedVariant::Recursive { sorted_tag_layouts: answer, } @@ -1585,6 +1611,7 @@ pub fn layout_from_tag_union<'a>( let mut tag_layouts = Vec::with_capacity_in(tags.len(), arena); tag_layouts.extend(tags.iter().map(|r| r.1)); + debug_assert!(tag_layouts.len() > 1); Layout::Union(UnionLayout::Recursive(tag_layouts.into_bump_slice())) } @@ -1603,6 +1630,7 @@ pub fn layout_from_tag_union<'a>( } NullableUnwrapped { .. } => todo!(), + NonNullableUnwrapped { .. } => todo!(), } } } @@ -1778,8 +1806,8 @@ pub fn list_layout_from_elem<'a>( // If this was still a (List *) then it must have been an empty list Ok(Layout::Builtin(Builtin::EmptyList)) } - content => { - let elem_layout = Layout::new_help(env, elem_var, content)?; + _ => { + let elem_layout = Layout::from_var(env, elem_var)?; // This is a normal list. Ok(Layout::Builtin(Builtin::List( diff --git a/compiler/mono/src/tail_recursion.rs b/compiler/mono/src/tail_recursion.rs index 87aa45038c..c7e150d898 100644 --- a/compiler/mono/src/tail_recursion.rs +++ b/compiler/mono/src/tail_recursion.rs @@ -228,8 +228,8 @@ fn insert_jumps<'a>( None } } - Inc(symbol, cont) => match insert_jumps(arena, cont, goal_id, needle) { - Some(cont) => Some(arena.alloc(Inc(*symbol, cont))), + Inc(symbol, inc, cont) => match insert_jumps(arena, cont, goal_id, needle) { + Some(cont) => Some(arena.alloc(Inc(*symbol, *inc, cont))), None => None, }, Dec(symbol, cont) => match insert_jumps(arena, cont, goal_id, needle) { diff --git a/compiler/problem/src/can.rs b/compiler/problem/src/can.rs index 6cf91131fc..1173630fb3 100644 --- a/compiler/problem/src/can.rs +++ b/compiler/problem/src/can.rs @@ -145,6 +145,9 @@ pub enum RuntimeError { InvalidUnicodeCodePoint(Region), /// When the author specifies a type annotation but no implementation + NoImplementationNamed { + def_symbol: Symbol, + }, NoImplementation, /// cases where the `[]` value (or equivalently, `forall a. a`) pops up diff --git a/compiler/reporting/src/error/canonicalize.rs b/compiler/reporting/src/error/canonicalize.rs index e13c0d1c42..8c12467eb0 100644 --- a/compiler/reporting/src/error/canonicalize.rs +++ b/compiler/reporting/src/error/canonicalize.rs @@ -635,7 +635,7 @@ fn pretty_runtime_error<'b>( region ); } - RuntimeError::NoImplementation => todo!("no implementation, unreachable"), + RuntimeError::NoImplementation | RuntimeError::NoImplementationNamed { .. } => todo!("no implementation, unreachable"), RuntimeError::NonExhaustivePattern => { unreachable!("not currently reported (but can blow up at runtime)") } diff --git a/compiler/reporting/src/error/type.rs b/compiler/reporting/src/error/type.rs index 42ade1863e..c80e4088bd 100644 --- a/compiler/reporting/src/error/type.rs +++ b/compiler/reporting/src/error/type.rs @@ -888,12 +888,12 @@ fn add_category<'b>( If => alloc.concat(vec![ alloc.text("This "), alloc.keyword("if"), - alloc.text("expression produces:"), + alloc.text(" expression produces:"), ]), When => alloc.concat(vec![ alloc.text("This "), alloc.keyword("when"), - alloc.text("expression produces:"), + alloc.text(" expression produces:"), ]), List => alloc.concat(vec![this_is, alloc.text(" a list of type:")]), diff --git a/compiler/reporting/tests/test_reporting.rs b/compiler/reporting/tests/test_reporting.rs index 5b58ebe4d2..aeaf14b1d8 100644 --- a/compiler/reporting/tests/test_reporting.rs +++ b/compiler/reporting/tests/test_reporting.rs @@ -1129,7 +1129,7 @@ mod test_reporting { 3│> when True is 4│> _ -> 3.14 - This `when`expression produces: + This `when` expression produces: Float a diff --git a/compiler/types/src/solved_types.rs b/compiler/types/src/solved_types.rs index 5fe89bfc35..3fc267e6a4 100644 --- a/compiler/types/src/solved_types.rs +++ b/compiler/types/src/solved_types.rs @@ -371,7 +371,7 @@ impl SolvedType { match subs.get_without_compacting(var).content { FlexVar(_) => SolvedType::Flex(VarId::from_var(var, subs)), - RecursionVar { .. } => todo!(), + RecursionVar { .. } => SolvedType::Flex(VarId::from_var(var, subs)), RigidVar(name) => SolvedType::Rigid(name), Structure(flat_type) => Self::from_flat_type(subs, recursion_vars, flat_type), Alias(symbol, args, actual_var) => { diff --git a/editor/Cargo.toml b/editor/Cargo.toml index fd79dd0fb8..97445e2fcb 100644 --- a/editor/Cargo.toml +++ b/editor/Cargo.toml @@ -82,9 +82,14 @@ indoc = "0.3.3" quickcheck = "0.8" quickcheck_macros = "0.8" criterion = "0.3" +rand = "0.8.2" [[bench]] -name = "my_benchmark" +name = "file_benchmark" +harness = false + +[[bench]] +name = "edit_benchmark" harness = false # uncomment everything below if you have made changes to any shaders and diff --git a/editor/benches/edit_benchmark.rs b/editor/benches/edit_benchmark.rs new file mode 100644 index 0000000000..c8b404a378 --- /dev/null +++ b/editor/benches/edit_benchmark.rs @@ -0,0 +1,199 @@ +use bumpalo::Bump; +use criterion::{criterion_group, criterion_main, Criterion}; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use roc_editor::mvc::app_model::AppModel; +use roc_editor::mvc::ed_model::{EdModel, Position, RawSelection}; +use roc_editor::mvc::update::handle_new_char; +use roc_editor::text_buffer; +use roc_editor::text_buffer::TextBuffer; +use ropey::Rope; +use std::cmp::min; +use std::path::Path; + +// duplicate inside mvc::update +fn mock_app_model( + text_buf: TextBuffer, + caret_pos: Position, + selection_opt: Option, +) -> AppModel { + AppModel { + ed_model_opt: Some(EdModel { + text_buf, + caret_pos, + selection_opt, + glyph_dim_rect_opt: None, + has_focus: true, + }), + } +} + +fn text_buffer_from_str(lines_str: &str) -> TextBuffer { + TextBuffer { + text_rope: Rope::from_str(lines_str), + path_str: "".to_owned(), + } +} + +pub fn char_insert_bench(c: &mut Criterion) { + let text_buf = text_buffer_from_str(""); + + let caret_pos = Position { line: 0, column: 0 }; + + let selection_opt: Option = None; + let mut app_model = mock_app_model(text_buf, caret_pos, selection_opt); + c.bench_function("single char insert, small buffer", |b| { + b.iter(|| handle_new_char(&mut app_model, &'a')) + }); +} + +pub fn char_pop_bench(c: &mut Criterion) { + let nr_lines = 50000; + let mut text_buf = buf_from_dummy_file(nr_lines); + + let mut rand_gen_pos = StdRng::seed_from_u64(44); + + c.bench_function( + &format!("single char pop, {} lines", text_buf.nr_of_lines()), + |b| { + b.iter(|| { + let max_line_nr = text_buf.nr_of_lines(); + let rand_line_nr = rand_gen_pos.gen_range(0..max_line_nr); + let max_col = text_buf + .line_len(rand_line_nr) + .expect("Failed to retrieve line length."); + + let caret_pos = Position { + line: rand_line_nr, + column: rand_gen_pos.gen_range(0..max_col), + }; + + text_buf.pop_char(caret_pos); + }) + }, + ); +} + +fn get_all_lines_helper(nr_lines: usize, c: &mut Criterion) { + let text_buf = buf_from_dummy_file(nr_lines); + + let arena = Bump::new(); + + c.bench_function( + &format!("get all {:?} lines from textbuffer", nr_lines), + |b| b.iter(|| text_buf.all_lines(&arena)), + ); +} + +fn get_all_lines_bench(c: &mut Criterion) { + get_all_lines_helper(10000, c) +} + +fn get_line_len_helper(nr_lines: usize, c: &mut Criterion) { + let text_buf = buf_from_dummy_file(nr_lines); + + let mut rand_gen = StdRng::seed_from_u64(45); + + c.bench_function( + &format!("get random line len from {:?}-line textbuffer", nr_lines), + |b| b.iter(|| text_buf.line_len(rand_gen.gen_range(0..nr_lines)).unwrap()), + ); +} + +fn get_line_len_bench(c: &mut Criterion) { + get_line_len_helper(10000, c) +} + +fn get_line_helper(nr_lines: usize, c: &mut Criterion) { + let text_buf = buf_from_dummy_file(nr_lines); + + let mut rand_gen = StdRng::seed_from_u64(46); + + c.bench_function( + &format!("get random line from {:?}-line textbuffer", nr_lines), + |b| b.iter(|| text_buf.line(rand_gen.gen_range(0..nr_lines)).unwrap()), + ); +} + +fn get_line_bench(c: &mut Criterion) { + get_line_helper(10000, c) +} + +pub fn del_select_bench(c: &mut Criterion) { + let nr_lines = 25000000; + let mut text_buf = buf_from_dummy_file(nr_lines); + + let mut rand_gen = StdRng::seed_from_u64(47); + + c.bench_function( + &format!( + "delete rand selection, {}-line file", + text_buf.nr_of_lines() + ), + |b| { + b.iter(|| { + let rand_sel = gen_rand_selection(&mut rand_gen, &text_buf); + + text_buf.del_selection(rand_sel).unwrap(); + }) + }, + ); +} + +fn gen_rand_selection(rand_gen: &mut StdRng, text_buf: &TextBuffer) -> RawSelection { + let max_line_nr = text_buf.nr_of_lines(); + let rand_line_nr_a = rand_gen.gen_range(0..max_line_nr - 3); + let max_col_a = text_buf.line_len(rand_line_nr_a).expect(&format!( + "Failed to retrieve line length. For line {}, with {} lines in buffer", + rand_line_nr_a, + text_buf.nr_of_lines() + )); + let rand_col_a = if max_col_a > 0 { + rand_gen.gen_range(0..max_col_a) + } else { + 0 + }; + + let max_sel_end = min(rand_line_nr_a + 5, max_line_nr); + let rand_line_nr_b = rand_gen.gen_range((rand_line_nr_a + 1)..max_sel_end); + let max_col_b = text_buf.line_len(rand_line_nr_b).expect(&format!( + "Failed to retrieve line length. For line {}, with {} lines in buffer", + rand_line_nr_b, + text_buf.nr_of_lines() + )); + let rand_col_b = if max_col_b > 0 { + rand_gen.gen_range(0..max_col_b) + } else { + 0 + }; + + RawSelection { + start_pos: Position { + line: rand_line_nr_a, + column: rand_col_a, + }, + end_pos: Position { + line: rand_line_nr_b, + column: rand_col_b, + }, + } +} + +fn buf_from_dummy_file(nr_lines: usize) -> TextBuffer { + let path_str = format!("benches/resources/{}_lines.roc", nr_lines); + + text_buffer::from_path(Path::new(&path_str)).expect("Failed to read file at given path.") +} + +//TODO remove all random generation from inside measured execution block +//criterion_group!(benches, del_select_bench); +criterion_group!( + benches, + char_pop_bench, + char_insert_bench, + get_all_lines_bench, + get_line_len_bench, + get_line_bench, + del_select_bench +); +criterion_main!(benches); diff --git a/editor/benches/file_benchmark.rs b/editor/benches/file_benchmark.rs new file mode 100644 index 0000000000..ca254312b7 --- /dev/null +++ b/editor/benches/file_benchmark.rs @@ -0,0 +1,123 @@ +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rand::distributions::Alphanumeric; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use roc_editor::mvc::app_model::AppModel; +use roc_editor::mvc::ed_model::{EdModel, Position, RawSelection}; +use roc_editor::mvc::update::handle_new_char; +use roc_editor::text_buffer; +use roc_editor::text_buffer::TextBuffer; +use ropey::Rope; +use std::fs::File; +use std::io::Write; +use std::path::Path; + +// duplicate inside mvc::update +fn mock_app_model( + text_buf: TextBuffer, + caret_pos: Position, + selection_opt: Option, +) -> AppModel { + AppModel { + ed_model_opt: Some(EdModel { + text_buf, + caret_pos, + selection_opt, + glyph_dim_rect_opt: None, + has_focus: true, + }), + } +} + +fn text_buffer_from_str(lines_str: &str) -> TextBuffer { + TextBuffer { + text_rope: Rope::from_str(lines_str), + path_str: "".to_owned(), + } +} + +pub fn char_insert_benchmark(c: &mut Criterion) { + let text_buf = text_buffer_from_str(""); + + let caret_pos = Position { line: 0, column: 0 }; + + let selection_opt: Option = None; + let mut app_model = mock_app_model(text_buf, caret_pos, selection_opt); + c.bench_function("single char insert, small buffer", |b| { + b.iter(|| handle_new_char(&mut app_model, &'a')) + }); +} + +static ROC_SOURCE_START: &str = "interface LongStrProvider + exposes [ longStr ] + imports [] + +longStr : Str +longStr = + \"\"\""; + +static ROC_SOURCE_END: &str = "\"\"\""; + +fn line_count(lines: &str) -> usize { + lines.matches("\n").count() +} + +pub fn gen_file(nr_lines: usize) { + let nr_of_str_lines = nr_lines - line_count(ROC_SOURCE_START); + let path_str = format!("benches/resources/{:?}_lines.roc", nr_lines); + let path = Path::new(&path_str); + let display = path.display(); + + // Open a file in write-only mode, returns `io::Result` + let mut file = match File::create(&path) { + Err(why) => panic!("couldn't create {}: {}", display, why), + Ok(file) => file, + }; + + file.write(ROC_SOURCE_START.as_bytes()) + .expect("Failed to write String to file."); + + let mut rand_gen_line = StdRng::seed_from_u64(42); + + for _ in 0..nr_of_str_lines { + let line_len = rand_gen_line.gen_range(1..90); + + let char_seed = rand_gen_line.gen_range(0..1000); + + let mut rand_string: String = StdRng::seed_from_u64(char_seed) + .sample_iter(&Alphanumeric) + .take(line_len) + .map(char::from) + .collect(); + rand_string.push('\n'); + + file.write(rand_string.as_bytes()) + .expect("Failed to write String to file."); + } + + file.write(ROC_SOURCE_END.as_bytes()) + .expect("Failed to write String to file."); +} + +fn file_read_bench_helper(nr_lines: usize, c: &mut Criterion) { + let path_str = format!("benches/resources/{}_lines.roc", nr_lines); + text_buffer::from_path(Path::new(&path_str)).expect("Failed to read file at given path."); + c.bench_function( + &format!("read {:?} line file into textbuffer", nr_lines), + |b| b.iter(|| text_buffer::from_path(black_box(Path::new(&path_str)))), + ); +} + +fn file_read_bench(c: &mut Criterion) { + // generate dummy files + /*let lines_vec = vec![100, 500, 1000, 10000, 50000, 100000, 25000000]; + + for nr_lines in lines_vec.iter(){ + gen_file(*nr_lines); + }*/ + + file_read_bench_helper(10, c) +} + +criterion_group!(benches, file_read_bench); +criterion_main!(benches); diff --git a/editor/benches/my_benchmark.rs b/editor/benches/my_benchmark.rs deleted file mode 100644 index 7a77e9f937..0000000000 --- a/editor/benches/my_benchmark.rs +++ /dev/null @@ -1,50 +0,0 @@ -use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use roc_editor::mvc::app_model::AppModel; -use roc_editor::mvc::ed_model::{EdModel, Position, RawSelection}; -use roc_editor::mvc::update::handle_new_char; -use roc_editor::text_buffer::TextBuffer; -use ropey::Rope; - -// duplicate inside mvc::update -fn mock_app_model( - text_buf: TextBuffer, - caret_pos: Position, - selection_opt: Option, -) -> AppModel { - AppModel { - ed_model_opt: Some(EdModel { - text_buf, - caret_pos, - selection_opt, - glyph_dim_rect_opt: None, - has_focus: true, - }), - } -} - -fn text_buffer_from_str(lines_str: &str) -> TextBuffer { - TextBuffer { - text_rope: Rope::from_str(lines_str), - path_str: "".to_owned(), - } -} - -pub fn char_insert_benchmark(c: &mut Criterion) { - let text_buf = text_buffer_from_str(""); - - let caret_pos = Position { line: 0, column: 0 }; - - let selection_opt: Option = None; - let mut app_model = mock_app_model(text_buf, caret_pos, selection_opt); - c.bench_function("single char insert, small buffer", |b| { - b.iter(|| handle_new_char(&mut app_model, &'a')) - }); -} - -pub fn file_open_benchmark(c: &mut Criterion) { - ed_model::init_model(path) - //TODO continue here -} - -criterion_group!(benches, char_insert_benchmark); -criterion_main!(benches); diff --git a/editor/benches/resources/10_lines.roc b/editor/benches/resources/10_lines.roc new file mode 100644 index 0000000000..8b322cd59c --- /dev/null +++ b/editor/benches/resources/10_lines.roc @@ -0,0 +1,11 @@ +interface LongStrProvider + exposes [ longStr ] + imports [] + +longStr : Str +longStr = + """7vntt4wlBKiVkNss19DZlOfmSAyIzO5Ph8eckYgnctYDersOFs3AWOPHcONxI58DoTEwGKNLGkhrxwCD +gWxYsX9hlEuQ0vI4twHMqgj8F +Ox4pVYIxku15v1KaWahgjkJ8EBXMWhe5m2519wpEtP +HtaqU0XzVu1ix3jGAZ66UugNKJrVP8RVQm +""" \ No newline at end of file diff --git a/editor/benches/results/file_read.txt b/editor/benches/results/file_read.txt new file mode 100644 index 0000000000..d7fa8777f6 --- /dev/null +++ b/editor/benches/results/file_read.txt @@ -0,0 +1,26 @@ +System info: +- CPU: Intel i7 4770k +- SSD: Samsung 970 EVO PLUS M.2 1TB +- OS: Ubuntu 20.04 + +c.bench_function( + "read file into textbuffer", + |b| b.iter(|| text_buffer::from_path(black_box(Path::new(path_str)))) + ); + +10 lines, 285 B time: [3.2343 us] + +100 lines, 4.2 KiB time: [6.1810 us] + +500 lines, 22.2 KiB time: [15.689 us] + +1000 lines, 44.6 KiB time: [29.591 us] + +10000 lines, 448 KiB time: [376.22 us] + +50000 lines, 2.2 MiB time: [2.0329 ms] + +100000 lines, 4.4 MiB time: [4.4221 ms] + +25000000 lines, 1.1 GiB time: [1.1333 s] + diff --git a/editor/src/lib.rs b/editor/src/lib.rs index cfdf2c6eb2..6ebfd8be74 100644 --- a/editor/src/lib.rs +++ b/editor/src/lib.rs @@ -46,12 +46,12 @@ pub mod error; pub mod graphics; mod keyboard_input; pub mod lang; -//mod mvc; -pub mod mvc; // for benchmarking +mod mvc; +//pub mod mvc; // for benchmarking mod resources; mod selection; -//mod text_buffer; -pub mod text_buffer; // for benchmarking +mod text_buffer; +//pub mod text_buffer; // for benchmarking mod util; mod vec_result; diff --git a/editor/src/mvc/ed_model.rs b/editor/src/mvc/ed_model.rs index fc2cad09dd..6a127031a5 100644 --- a/editor/src/mvc/ed_model.rs +++ b/editor/src/mvc/ed_model.rs @@ -3,6 +3,7 @@ use crate::graphics::primitives::rect::Rect; use crate::text_buffer; use crate::text_buffer::TextBuffer; use std::cmp::Ordering; +use std::fmt; use std::path::Path; #[derive(Debug)] @@ -55,3 +56,13 @@ pub struct RawSelection { pub start_pos: Position, pub end_pos: Position, } + +impl std::fmt::Display for RawSelection { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "RawSelection: start_pos: line:{} col:{}, end_pos: line:{} col:{}", + self.start_pos.line, self.start_pos.column, self.end_pos.line, self.end_pos.column + ) + } +} diff --git a/editor/src/text_buffer.rs b/editor/src/text_buffer.rs index dc4fbdea1a..74ed23205f 100644 --- a/editor/src/text_buffer.rs +++ b/editor/src/text_buffer.rs @@ -91,7 +91,7 @@ impl TextBuffer { } // expensive function, don't use it if it can be done with a specialized, more efficient function - // TODO use bump allocation here + // TODO use pool allocation here pub fn all_lines<'a>(&self, arena: &'a Bump) -> BumpString<'a> { let mut lines = BumpString::with_capacity_in(self.text_rope.len_chars(), arena); diff --git a/examples/.gitignore b/examples/.gitignore index 95e293ab00..509ab35b68 100644 --- a/examples/.gitignore +++ b/examples/.gitignore @@ -7,4 +7,5 @@ roc_app.bc benchmarks/nqueens benchmarks/deriv benchmarks/cfold +benchmarks/rbtree-insert effect-example diff --git a/examples/benchmarks/NQueens.roc b/examples/benchmarks/NQueens.roc index 278aadc0cf..aa2a89140e 100644 --- a/examples/benchmarks/NQueens.roc +++ b/examples/benchmarks/NQueens.roc @@ -5,7 +5,7 @@ app "nqueens" main : Task.Task {} [] main = - queens 10 + queens 6 |> Str.fromInt |> Task.putLine diff --git a/examples/benchmarks/RBTreeCk.roc b/examples/benchmarks/RBTreeCk.roc new file mode 100644 index 0000000000..2a1d3a03a1 --- /dev/null +++ b/examples/benchmarks/RBTreeCk.roc @@ -0,0 +1,123 @@ +app "rbtree-ck" + packages { base: "platform" } + imports [base.Task] + provides [ main ] to base + + +Color : [ Red, Black ] + +Tree a b : [ Leaf, Node Color (Tree a b) a b (Tree a b) ] + +Map : Tree I64 Bool + +ConsList a : [ Nil, Cons a (ConsList a) ] + +makeMap : I64, I64 -> ConsList Map +makeMap = \freq, n -> + makeMapHelp freq n Leaf Nil + +makeMapHelp : I64, I64, Map, ConsList Map -> ConsList Map +makeMapHelp = \freq, n, m, acc -> + when n is + 0 -> Cons m acc + _ -> + powerOf10 = + (n % 10 |> resultWithDefault 0) == 0 + + + m1 = insert m n powerOf10 + + isFrequency = + (n % freq |> resultWithDefault 0) == 0 + + x = (if isFrequency then (Cons m1 acc) else acc) + makeMapHelp freq (n-1) m1 x + +fold : (a, b, omega -> omega), Tree a b, omega -> omega +fold = \f, tree, b -> + when tree is + Leaf -> b + Node _ l k v r -> fold f r (f k v (fold f l b)) + +resultWithDefault : Result a e, a -> a +resultWithDefault = \res, default -> + when res is + Ok v -> v + Err _ -> default + +main : Task.Task {} [] +main = + ms : ConsList Map + ms = makeMap 5 5 # 42_000_00 + + when ms is + Cons head _ -> + val = fold (\_, v, r -> if v then r + 1 else r) head 0 + val + |> Str.fromInt + |> Task.putLine + + Nil -> + Task.putLine "fail" + +insert : Map, I64, Bool -> Map +insert = \t, k, v -> if isRed t then setBlack (ins t k v) else ins t k v + + +setBlack : Tree a b -> Tree a b +setBlack = \tree -> + when tree is + Node _ l k v r -> Node Black l k v r + _ -> tree + +isRed : Tree a b -> Bool +isRed = \tree -> + when tree is + Node Red _ _ _ _ -> True + _ -> False + +lt = \x, y -> x < y + +ins : Map, I64, Bool -> Map +ins = \tree, kx, vx -> + when tree is + Leaf -> + Node Red Leaf kx vx Leaf + + Node Red a ky vy b -> + if lt kx ky then + Node Red (ins a kx vx) ky vy b + else if lt ky kx then + Node Red a ky vy (ins b kx vx) + else + Node Red a ky vy (ins b kx vx) + + Node Black a ky vy b -> + if lt kx ky then + (if isRed a then balance1 (Node Black Leaf ky vy b) (ins a kx vx) else Node Black (ins a kx vx) ky vy b) + else if lt ky kx then + (if isRed b then balance2 (Node Black a ky vy Leaf) (ins b kx vx) else Node Black a ky vy (ins b kx vx)) + else + Node Black a kx vx b + +balance1 : Map, Map -> Map +balance1 = \tree1, tree2 -> + when tree1 is + Leaf -> Leaf + Node _ _ kv vv t -> + when tree2 is + Node _ (Node Red l kx vx r1) ky vy r2 -> Node Red (Node Black l kx vx r1) ky vy (Node Black r2 kv vv t) + Node _ l1 ky vy (Node Red l2 kx vx r) -> Node Red (Node Black l1 ky vy l2) kx vx (Node Black r kv vv t) + Node _ l ky vy r -> Node Black (Node Red l ky vy r) kv vv t + Leaf -> Leaf + +balance2 : Map, Map -> Map +balance2 = \tree1, tree2 -> + when tree1 is + Leaf -> Leaf + Node _ t kv vv _ -> + when tree2 is + Node _ (Node Red l kx1 vx1 r1) ky vy r2 -> Node Red (Node Black t kv vv l) kx1 vx1 (Node Black r1 ky vy r2) + Node _ l1 ky vy (Node Red l2 kx2 vx2 r2) -> Node Red (Node Black t kv vv l1) ky vy (Node Black l2 kx2 vx2 r2) + Node _ l ky vy r -> Node Black t kv vv (Node Red l ky vy r) + Leaf -> Leaf diff --git a/examples/benchmarks/RBTreeDel.roc b/examples/benchmarks/RBTreeDel.roc new file mode 100644 index 0000000000..31cf63b2fa --- /dev/null +++ b/examples/benchmarks/RBTreeDel.roc @@ -0,0 +1,216 @@ +app "rbtree-del" + packages { base: "platform" } + imports [base.Task] + provides [ main ] to base + + +Color : [ Red, Black ] + +Tree a b : [ Leaf, Node Color (Tree a b) a b (Tree a b) ] + +Map : Tree I64 Bool + +ConsList a : [ Nil, Cons a (ConsList a) ] + +main : Task.Task {} [] +main = + # benchmarks use 4_200_000 + m = makeMap 420 + + val = fold (\_, v, r -> if v then r + 1 else r) m 0 + + val + |> Str.fromInt + |> Task.putLine + +boom : Str -> a +boom = \_ -> boom "" + +makeMap : I64 -> Map +makeMap = \n -> + makeMapHelp n n Leaf + +makeMapHelp : I64, I64, Map -> Map +makeMapHelp = \total, n, m -> + when n is + 0 -> m + _ -> + n1 = n - 1 + + powerOf10 = + (n % 10 |> resultWithDefault 0) == 0 + + t1 = insert m n powerOf10 + + isFrequency = + (n % 4 |> resultWithDefault 0) == 0 + + key = n1 + ((total - n1) // 5 |> resultWithDefault 0) + t2 = if isFrequency then delete t1 key else t1 + + makeMapHelp total n1 t2 + +fold : (a, b, omega -> omega), Tree a b, omega -> omega +fold = \f, tree, b -> + when tree is + Leaf -> b + Node _ l k v r -> fold f r (f k v (fold f l b)) + +depth : Tree * * -> I64 +depth = \tree -> + when tree is + Leaf -> 1 + Node _ l _ _ r -> 1 + depth l + depth r + +resultWithDefault : Result a e, a -> a +resultWithDefault = \res, default -> + when res is + Ok v -> v + Err _ -> default + + +insert : Map, I64, Bool -> Map +insert = \t, k, v -> if isRed t then setBlack (ins t k v) else ins t k v + + +setBlack : Tree a b -> Tree a b +setBlack = \tree -> + when tree is + Node _ l k v r -> Node Black l k v r + _ -> tree + +isRed : Tree a b -> Bool +isRed = \tree -> + when tree is + Node Red _ _ _ _ -> True + _ -> False + +lt = \x, y -> x < y + +ins : Map, I64, Bool -> Map +ins = \tree, kx, vx -> + when tree is + Leaf -> + Node Red Leaf kx vx Leaf + + Node Red a ky vy b -> + if lt kx ky then + Node Red (ins a kx vx) ky vy b + else if lt ky kx then + Node Red a ky vy (ins b kx vx) + else + Node Red a ky vy (ins b kx vx) + + Node Black a ky vy b -> + if lt kx ky then + (if isRed a then balanceLeft (ins a kx vx) ky vy b else Node Black (ins a kx vx) ky vy b) + else if lt ky kx then + (if isRed b then balanceRight a ky vy (ins b kx vx) else Node Black a ky vy (ins b kx vx)) + else Node Black a kx vx b + +balanceLeft : Tree a b, a, b, Tree a b -> Tree a b +balanceLeft = \l, k, v, r -> + when l is + Leaf -> Leaf + Node _ (Node Red lx kx vx rx) ky vy ry + -> Node Red (Node Black lx kx vx rx) ky vy (Node Black ry k v r) + Node _ ly ky vy (Node Red lx kx vx rx) + -> Node Red (Node Black ly ky vy lx) kx vx (Node Black rx k v r) + Node _ lx kx vx rx + -> Node Black (Node Red lx kx vx rx) k v r + +balanceRight : Tree a b, a, b, Tree a b -> Tree a b +balanceRight = \l, k, v, r -> + when r is + Leaf -> Leaf + Node _ (Node Red lx kx vx rx) ky vy ry + -> Node Red (Node Black l k v lx) kx vx (Node Black rx ky vy ry) + Node _ lx kx vx (Node Red ly ky vy ry) + -> Node Red (Node Black l k v lx) kx vx (Node Black ly ky vy ry) + Node _ lx kx vx rx + -> Node Black l k v (Node Red lx kx vx rx) + +isBlack : Color -> Bool +isBlack = \c -> + when c is + Black -> True + Red -> False + + +Del a b : [ Del (Tree a b) Bool ] + +setRed : Map -> Map +setRed = \t -> + when t is + Node _ l k v r -> Node Red l k v r + _ -> t + + + +makeBlack : Map -> Del I64 Bool +makeBlack = \t -> + when t is + Node Red l k v r -> Del (Node Black l k v r) False + _ -> Del t True + + +rebalanceLeft = \c, l, k, v, r -> + when l is + Node Black _ _ _ _ -> Del (balanceLeft (setRed l) k v r) (isBlack c) + Node Red lx kx vx rx -> Del (Node Black lx kx vx (balanceLeft (setRed rx) k v r)) False + _ -> boom "unreachable" + +rebalanceRight = \c, l, k, v, r -> + when r is + Node Black _ _ _ _ -> Del (balanceRight l k v (setRed r)) (isBlack c) + Node Red lx kx vx rx -> Del (Node Black (balanceRight l k v (setRed lx)) kx vx rx) False + _ -> boom "unreachable" + + + +delMin = \t -> + when t is + Node Black Leaf k v r -> + when r is + Leaf -> Delmin (Del Leaf True) k v + _ -> Delmin (Del (setBlack r) False) k v + + Node Red Leaf k v r -> + Delmin (Del r False) k v + + Node c l k v r -> + when delMin l is + Delmin (Del lx True) kx vx -> Delmin (rebalanceRight c lx k v r) kx vx + Delmin (Del lx False) kx vx -> Delmin (Del (Node c lx k v r) False) kx vx + + Leaf -> + Delmin (Del t False) 0 False + + + +delete : Map, I64 -> Map +delete = \t, k -> + when del t k is + Del tx _ -> setBlack tx + +del = \t, k -> + when t is + Leaf -> Del Leaf False + Node cx lx kx vx rx -> + if (k < kx) then + when (del lx k) is + Del ly True -> rebalanceRight cx ly kx vx rx + Del ly False -> Del (Node cx ly kx vx rx) False + + else if (k > kx) then + when (del rx k) is + Del ry True -> rebalanceLeft cx lx kx vx ry + Del ry False -> Del (Node cx lx kx vx ry) False + + else + when rx is + Leaf -> if isBlack cx then makeBlack lx else Del lx False + Node _ _ _ _ _ -> + when delMin rx is + Delmin (Del ry True) ky vy -> rebalanceLeft cx lx ky vy ry + Delmin (Del ry False) ky vy -> Del (Node cx lx ky vy ry) False diff --git a/examples/benchmarks/RBTreeInsert.roc b/examples/benchmarks/RBTreeInsert.roc new file mode 100644 index 0000000000..e12b4e700d --- /dev/null +++ b/examples/benchmarks/RBTreeInsert.roc @@ -0,0 +1,106 @@ +app "rbtree-insert" + packages { base: "platform" } + imports [base.Task] + provides [ main ] to base + +main : Task.Task {} [] +main = + tree : RedBlackTree I64 {} + tree = insert 0 {} Empty + + tree + |> show + |> Task.putLine + +show : RedBlackTree I64 {} -> Str +show = \tree -> showRBTree tree Str.fromInt (\{} -> "{}") + +showRBTree : RedBlackTree k v, (k -> Str), (v -> Str) -> Str +showRBTree = \tree, showKey, showValue -> + when tree is + Empty -> "Empty" + Node color key value left right -> + sColor = showColor color + sKey = showKey key + sValue = showValue value + sL = nodeInParens left showKey showValue + sR = nodeInParens right showKey showValue + "Node \(sColor) \(sKey) \(sValue) \(sL) \(sR)" + +nodeInParens : RedBlackTree k v, (k -> Str), (v -> Str) -> Str +nodeInParens = \tree, showKey, showValue -> + when tree is + Empty -> showRBTree tree showKey showValue + Node _ _ _ _ _ -> + inner = showRBTree tree showKey showValue + "(\(inner))" + +showColor : NodeColor -> Str +showColor = \color -> + when color is + Red -> "Red" + Black -> "Black" + +NodeColor : [ Red, Black ] + +RedBlackTree k v : [ Node NodeColor k v (RedBlackTree k v) (RedBlackTree k v), Empty ] + +Key k : Num k + +insert : Key k, v, RedBlackTree (Key k) v -> RedBlackTree (Key k) v +insert = \key, value, dict -> + when insertHelp key value dict is + Node Red k v l r -> + Node Black k v l r + + x -> + x + +insertHelp : (Key k), v, RedBlackTree (Key k) v -> RedBlackTree (Key k) v +insertHelp = \key, value, dict -> + when dict is + Empty -> + # New nodes are always red. If it violates the rules, it will be fixed + # when balancing. + Node Red key value Empty Empty + + Node nColor nKey nValue nLeft nRight -> + when Num.compare key nKey is + LT -> + balance nColor nKey nValue (insertHelp key value nLeft) nRight + + EQ -> + Node nColor nKey value nLeft nRight + + GT -> + balance nColor nKey nValue nLeft (insertHelp key value nRight) + +balance : NodeColor, k, v, RedBlackTree k v, RedBlackTree k v -> RedBlackTree k v +balance = \color, key, value, left, right -> + when right is + Node Red rK rV rLeft rRight -> + when left is + Node Red lK lV lLeft lRight -> + Node + Red + key + value + (Node Black lK lV lLeft lRight) + (Node Black rK rV rLeft rRight) + + _ -> + Node color rK rV (Node Red key value left rLeft) rRight + + _ -> + when left is + Node Red lK lV (Node Red llK llV llLeft llRight) lRight -> + Node + Red + lK + lV + (Node Black llK llV llLeft llRight) + (Node Black key value lRight right) + + _ -> + Node color key value left right + diff --git a/examples/benchmarks/platform/host.zig b/examples/benchmarks/platform/host.zig index 041e1b7e94..af91e0bbb4 100644 --- a/examples/benchmarks/platform/host.zig +++ b/examples/benchmarks/platform/host.zig @@ -88,13 +88,8 @@ fn call_the_closure(function_pointer: *const u8, closure_data_pointer: [*]u8) vo pub export fn roc_fx_putLine(rocPath: str.RocStr) i64 { const stdout = std.io.getStdOut().writer(); - const u8_ptr = rocPath.asU8ptr(); - - var i: usize = 0; - while (i < rocPath.len()) { - stdout.print("{c}", .{u8_ptr[i]}) catch unreachable; - - i += 1; + for (rocPath.asSlice()) |char| { + stdout.print("{c}", .{char}) catch unreachable; } stdout.print("\n", .{}) catch unreachable; diff --git a/examples/benchmarks/platform/str.zig b/examples/benchmarks/platform/str.zig index 163c58a123..7c4a2cbf8c 100644 --- a/examples/benchmarks/platform/str.zig +++ b/examples/benchmarks/platform/str.zig @@ -260,6 +260,11 @@ pub const RocStr = extern struct { } }; +// Str.equal +pub fn strEqual(self: RocStr, other: RocStr) callconv(.C) bool { + return self.eq(other); +} + // Str.numberOfBytes pub fn strNumberOfBytes(string: RocStr) callconv(.C) usize { return string.len(); @@ -603,100 +608,8 @@ test "countSegments: delimiter interspered" { expectEqual(segments_count, 3); } -// Str.countGraphemeClusters -const grapheme = @import("helpers/grapheme.zig"); -pub fn countGraphemeClusters(string: RocStr) callconv(.C) usize { - if (string.isEmpty()) { - return 0; - } - - const bytes_len = string.len(); - const bytes_ptr = string.asU8ptr(); - - var bytes = bytes_ptr[0..bytes_len]; - var iter = (unicode.Utf8View.init(bytes) catch unreachable).iterator(); - - var count: usize = 0; - var grapheme_break_state: ?grapheme.BoundClass = null; - var grapheme_break_state_ptr = &grapheme_break_state; - var opt_last_codepoint: ?u21 = null; - while (iter.nextCodepoint()) |cur_codepoint| { - if (opt_last_codepoint) |last_codepoint| { - var did_break = grapheme.isGraphemeBreak(last_codepoint, cur_codepoint, grapheme_break_state_ptr); - if (did_break) { - count += 1; - grapheme_break_state = null; - } - } - opt_last_codepoint = cur_codepoint; - } - - // If there are no breaks, but the str is not empty, then there - // must be a single grapheme - if (bytes_len != 0) { - count += 1; - } - - return count; -} - fn rocStrFromLiteral(bytes_arr: *const []u8) RocStr {} -test "countGraphemeClusters: empty string" { - const count = countGraphemeClusters(RocStr.empty()); - expectEqual(count, 0); -} - -test "countGraphemeClusters: ascii characters" { - const bytes_arr = "abcd"; - const bytes_len = bytes_arr.len; - const str = RocStr.init(testing.allocator, bytes_arr, bytes_len); - defer str.deinit(testing.allocator); - - const count = countGraphemeClusters(str); - expectEqual(count, 4); -} - -test "countGraphemeClusters: utf8 characters" { - const bytes_arr = "ãxā"; - const bytes_len = bytes_arr.len; - const str = RocStr.init(testing.allocator, bytes_arr, bytes_len); - defer str.deinit(testing.allocator); - - const count = countGraphemeClusters(str); - expectEqual(count, 3); -} - -test "countGraphemeClusters: emojis" { - const bytes_arr = "🤔🤔🤔"; - const bytes_len = bytes_arr.len; - const str = RocStr.init(testing.allocator, bytes_arr, bytes_len); - defer str.deinit(testing.allocator); - - const count = countGraphemeClusters(str); - expectEqual(count, 3); -} - -test "countGraphemeClusters: emojis and ut8 characters" { - const bytes_arr = "🤔å🤔¥🤔ç"; - const bytes_len = bytes_arr.len; - const str = RocStr.init(testing.allocator, bytes_arr, bytes_len); - defer str.deinit(testing.allocator); - - const count = countGraphemeClusters(str); - expectEqual(count, 6); -} - -test "countGraphemeClusters: emojis, ut8, and ascii characters" { - const bytes_arr = "6🤔å🤔e¥🤔çpp"; - const bytes_len = bytes_arr.len; - const str = RocStr.init(testing.allocator, bytes_arr, bytes_len); - defer str.deinit(testing.allocator); - - const count = countGraphemeClusters(str); - expectEqual(count, 10); -} - // Str.startsWith pub fn startsWith(string: RocStr, prefix: RocStr) callconv(.C) bool { const bytes_len = string.len(); @@ -826,9 +739,7 @@ fn strConcatHelp(allocator: *Allocator, comptime T: type, result_in_place: InPla var result = RocStr.initBig(allocator, T, result_in_place, combined_length); { - const old_if_small = &@bitCast([16]u8, arg1); - const old_if_big = @ptrCast([*]u8, arg1.str_bytes); - const old_bytes = if (arg1.isSmallStr()) old_if_small else old_if_big; + const old_bytes = arg1.asU8ptr(); const new_bytes: [*]u8 = @ptrCast([*]u8, result.str_bytes); @@ -836,9 +747,7 @@ fn strConcatHelp(allocator: *Allocator, comptime T: type, result_in_place: InPla } { - const old_if_small = &@bitCast([16]u8, arg2); - const old_if_big = @ptrCast([*]u8, arg2.str_bytes); - const old_bytes = if (arg2.isSmallStr()) old_if_small else old_if_big; + const old_bytes = arg2.asU8ptr(); const new_bytes = @ptrCast([*]u8, result.str_bytes) + arg1.len();