diff --git a/compiler/builtins/bitcode/src/str.zig b/compiler/builtins/bitcode/src/str.zig index 1cd2ba8a2d..67f4892bba 100644 --- a/compiler/builtins/bitcode/src/str.zig +++ b/compiler/builtins/bitcode/src/str.zig @@ -1511,11 +1511,13 @@ pub fn strTrim(string: RocStr) callconv(.C) RocStr { return RocStr.empty(); } + // NOTE: the SIGSEGV occurs somewhere after this point + const leading_bytes = countLeadingWhitespaceBytes(string); const trailing_bytes = countTrailingWhitespaceBytes(string); const new_len = string.len() - leading_bytes - trailing_bytes; - if (new_len == 0) { + if (new_len <= 0) { return RocStr.empty(); } @@ -1525,8 +1527,8 @@ pub fn strTrim(string: RocStr) callconv(.C) RocStr { // could also just inline the unsafe reallocate call // SIGSEGV is not from this branch - if (string.isRefcountOne()) { - const dest = string.str_bytes orelse unreachable; + if (string.isRefcountOne() and !string.isSmallStr()) { + const dest = string.str_bytes orelse return RocStr.empty(); const source = dest + leading_bytes; @memcpy(dest, source, new_len); return string.reallocate(new_len); @@ -1542,7 +1544,7 @@ fn countLeadingWhitespaceBytes(string: RocStr) usize { var iter = unicode.Utf8View.initUnchecked(bytes).iterator(); while (iter.nextCodepoint()) |codepoint| { if (isWhitespace(codepoint)) { - byte_count += unicode.utf8CodepointSequenceLength(codepoint) catch unreachable; + byte_count += unicode.utf8CodepointSequenceLength(codepoint) catch break; } else { break; } @@ -1558,7 +1560,7 @@ fn countTrailingWhitespaceBytes(string: RocStr) usize { var iter = ReverseUtf8View.initUnchecked(bytes).iterator(); while (iter.nextCodepoint()) |codepoint| { if (isWhitespace(codepoint)) { - byte_count += unicode.utf8CodepointSequenceLength(codepoint) catch unreachable; + byte_count += unicode.utf8CodepointSequenceLength(codepoint) catch break; } else { break; } @@ -1598,6 +1600,11 @@ const ReverseUtf8Iterator = struct { it.i -= 1; } + // TODO: this check should be unnecessary; if it triggers, the input is invalid UTF-8 + if (it.i < 0) { 
return null; + } + const cp_len = unicode.utf8ByteSequenceLength(it.bytes[it.i]) catch unreachable; const slice = it.bytes[it.i .. it.i + cp_len]; it.i -= 1; diff --git a/compiler/test_gen/src/gen_str.rs b/compiler/test_gen/src/gen_str.rs index 1172f62cdb..90f41063c1 100644 --- a/compiler/test_gen/src/gen_str.rs +++ b/compiler/test_gen/src/gen_str.rs @@ -1005,3 +1005,12 @@ fn str_trim_hello_world() { RocStr ); } + +#[test] +fn str_trim_hello_world_both_large() { + assert_evals_to!( + indoc!(r#"Str.trim " hello world world ""#), + RocStr::from("hello world world"), + RocStr + ); +}