diff --git a/compiler/gen/src/llvm/build_str.rs b/compiler/gen/src/llvm/build_str.rs index 10a06cd363..ddf2fdf564 100644 --- a/compiler/gen/src/llvm/build_str.rs +++ b/compiler/gen/src/llvm/build_str.rs @@ -24,15 +24,19 @@ pub fn str_concat<'a, 'ctx, 'env>( let second_str_ptr = ptr_from_symbol(scope, second_str_symbol); let first_str_ptr = ptr_from_symbol(scope, first_str_symbol); + let str_wrapper_type = BasicTypeEnum::StructType(collection(ctx, env.ptr_bytes)); + load_str( env, parent, second_str_ptr.clone(), + str_wrapper_type, |second_str_ptr, second_str_len, second_str_smallness| { load_str( env, parent, first_str_ptr.clone(), + str_wrapper_type, |first_str_ptr, first_str_len, first_str_smallness| { // first_str_len > 0 // We do this check to avoid allocating memory. If the first input @@ -66,7 +70,7 @@ pub fn str_concat<'a, 'ctx, 'env>( second_str_length_comparison, if_second_str_is_nonempty, if_second_str_is_empty, - BasicTypeEnum::StructType(collection(ctx, env.ptr_bytes)), + str_wrapper_type, ) }; @@ -199,19 +203,26 @@ pub fn str_concat<'a, 'ctx, 'env>( ) } -pub fn str_len<'a, 'ctx, 'env>( +fn str_len_from_final_byte<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + final_byte: IntValue<'ctx>, +) -> IntValue<'ctx> { + let builder = env.builder; + let ctx = env.context; + let bitmask = ctx.i8_type().const_int(0b0111_1111, false); + + builder.build_and(final_byte, bitmask, "small_str_length") +} + +#[allow(dead_code)] +fn str_len<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, parent: FunctionValue<'ctx>, wrapper_ptr: PointerValue<'ctx>, ) -> IntValue<'ctx> { let builder = env.builder; - let ctx = env.context; - let if_small = |final_byte| { - let bitmask = ctx.i8_type().const_int(0b0111_1111, false); - - BasicValueEnum::IntValue(builder.build_and(final_byte, bitmask, "small_str_length")) - }; + let if_small = |final_byte| BasicValueEnum::IntValue(str_len_from_final_byte(env, final_byte)); let if_big = |_| { BasicValueEnum::IntValue(list_len( @@ -237,21 +248,19 @@ fn load_str<'a, 'ctx, 'env, Callback>( env: &Env<'a, 'ctx, 'env>, parent: FunctionValue<'ctx>, wrapper_ptr: PointerValue<'ctx>, + ret_type: BasicTypeEnum<'ctx>, cb: Callback, ) -> BasicValueEnum<'ctx> where Callback: Fn(PointerValue<'ctx>, IntValue<'ctx>, Smallness) -> BasicValueEnum<'ctx>, { let builder = env.builder; - let ctx = env.context; let if_small = |final_byte| { - let bitmask = ctx.i8_type().const_int(0b0111_1111, false); - - let len = builder.build_and(final_byte, bitmask, "small_str_length"); + let len = str_len_from_final_byte(env, final_byte); cb( - wrapper_ptr, + cast_str_wrapper_to_array(env, wrapper_ptr), builder.build_int_cast(len, env.ptr_int(), "len_as_usize"), Smallness::Small, ) @@ -267,14 +276,7 @@ where cb(list_ptr, list_len(builder, wrapper_struct), Smallness::Big) }; - if_small_str( - env, - parent, - wrapper_ptr, - if_small, - if_big, - BasicTypeEnum::IntType(env.ptr_int()), - ) + if_small_str(env, parent, wrapper_ptr, if_small, if_big, ret_type) } #[derive(Debug, Copy, Clone)] @@ -296,9 +298,10 @@ fn clone_nonempty_str<'a, 'ctx, 'env>( // Allocate space for the new str that we'll copy into. match smallness { Smallness::Small => { - let wrapper_struct = builder.build_load(bytes_ptr, "str_wrapper"); - + let wrapper_struct_ptr = cast_str_bytes_to_wrapper(env, bytes_ptr); + let wrapper_struct = builder.build_load(wrapper_struct_ptr, "str_wrapper"); let alloca = builder.build_alloca(collection(ctx, ptr_bytes), "small_str_clone"); + builder.build_store(alloca, wrapper_struct); (wrapper_struct.into_struct_value(), alloca) @@ -346,7 +349,29 @@ fn clone_nonempty_str<'a, 'ctx, 'env>( } } -pub fn if_small_str<'a, 'ctx, 'env, IfSmallFn, IfBigFn>( +fn cast_str_bytes_to_wrapper<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + bytes_ptr: PointerValue<'ctx>, +) -> PointerValue<'ctx> { + let struct_ptr_type = collection(env.context, env.ptr_bytes).ptr_type(AddressSpace::Generic); + + env.builder + .build_bitcast(bytes_ptr, struct_ptr_type, "str_as_struct_ptr") + .into_pointer_value() +} + +fn cast_str_wrapper_to_array<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + wrapper_ptr: PointerValue<'ctx>, +) -> PointerValue<'ctx> { + let array_ptr_type = env.context.i8_type().ptr_type(AddressSpace::Generic); + + env.builder + .build_bitcast(wrapper_ptr, array_ptr_type, "str_as_array_ptr") + .into_pointer_value() +} + +fn if_small_str<'a, 'ctx, 'env, IfSmallFn, IfBigFn>( env: &Env<'a, 'ctx, 'env>, parent: FunctionValue<'ctx>, wrapper_ptr: PointerValue<'ctx>, @@ -360,13 +385,7 @@ where { let builder = env.builder; let ctx = env.context; - - let array_ptr_type = ctx.i8_type().ptr_type(AddressSpace::Generic); - - let byte_array_ptr = builder - .build_bitcast(wrapper_ptr, array_ptr_type, "str_as_array_ptr") - .into_pointer_value(); - + let byte_array_ptr = cast_str_wrapper_to_array(env, wrapper_ptr); let final_byte_ptr = unsafe { builder.build_in_bounds_gep( byte_array_ptr,