From 065fa748cf8498fa346780e90202c81c1dbd0f5c Mon Sep 17 00:00:00 2001 From: Leonard Hecker Date: Thu, 5 Jun 2025 19:34:07 +0200 Subject: [PATCH] Add specialized SIMD line seeking routines (#408) The previous `memchr` loop had the fatal flaw that it would break out of the SIMD routines every time it hit a newline. This resulted in a throughput drop down to ~250MB/s on my system in the worst case. By writing SIMD routines specific to newline seeking, we can bump that up by >500x. Navigating through 1GB of text now takes ~16ms independent of the contents. --- benches/lib.rs | 38 +++-- src/bin/edit/documents.rs | 9 +- src/buffer/mod.rs | 17 ++- src/simd/lines_bwd.rs | 283 +++++++++++++++++++++++++++++++++++++ src/simd/lines_fwd.rs | 281 ++++++++++++++++++++++++++++++++++++ src/simd/memrchr2.rs | 194 ------------------------- src/simd/mod.rs | 35 ++++- src/tui.rs | 4 +- src/unicode/measurement.rs | 116 --------------- 9 files changed, 643 insertions(+), 334 deletions(-) create mode 100644 src/simd/lines_bwd.rs create mode 100644 src/simd/lines_fwd.rs delete mode 100644 src/simd/memrchr2.rs diff --git a/benches/lib.rs b/benches/lib.rs index abc9ee1..738dbc0 100644 --- a/benches/lib.rs +++ b/benches/lib.rs @@ -3,7 +3,7 @@ use std::hint::black_box; use std::io::Cursor; -use std::mem; +use std::{mem, vec}; use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; use edit::helpers::*; @@ -133,18 +133,36 @@ fn bench_oklab(c: &mut Criterion) { }); } +fn bench_simd_lines_fwd(c: &mut Criterion) { + let mut group = c.benchmark_group("simd"); + let buf = vec![b'\n'; 128 * MEBI]; + + for &lines in &[1, 8, 128, KIBI, 128 * KIBI, 128 * MEBI] { + group.throughput(Throughput::Bytes(lines as u64)).bench_with_input( + BenchmarkId::new("lines_fwd", lines), + &lines, + |b, &lines| { + b.iter(|| simd::lines_fwd(black_box(&buf), 0, 0, lines as CoordType)); + }, + ); + } +} + fn bench_simd_memchr2(c: &mut Criterion) { let mut group = c.benchmark_group("simd"); 
- let mut buffer_u8 = [0u8; 2048]; + let mut buf = vec![0u8; 128 * MEBI + KIBI]; - for &bytes in &[8usize, 32 + 8, 64 + 8, KIBI + 8] { + // For small sizes we add a small offset of +8, + // to ensure we also benchmark the non-SIMD tail handling. + // For large sizes, its relative impact is negligible. + for &bytes in &[8usize, 128 + 8, KIBI, 128 * KIBI, 128 * MEBI] { group.throughput(Throughput::Bytes(bytes as u64 + 1)).bench_with_input( BenchmarkId::new("memchr2", bytes), &bytes, |b, &size| { - buffer_u8.fill(b'a'); - buffer_u8[size] = b'\n'; - b.iter(|| simd::memchr2(b'\n', b'\r', black_box(&buffer_u8), 0)); + buf.fill(b'a'); + buf[size] = b'\n'; + b.iter(|| simd::memchr2(b'\n', b'\r', black_box(&buf), 0)); }, ); } @@ -154,9 +172,12 @@ fn bench_simd_memset(c: &mut Criterion) { let mut group = c.benchmark_group("simd"); let name = format!("memset<{}>", std::any::type_name::()); let size = mem::size_of::(); - let mut buf: Vec = vec![Default::default(); 2048 / size]; + let mut buf: Vec = vec![Default::default(); 128 * MEBI / size]; - for &bytes in &[8usize, 32 + 8, 64 + 8, KIBI + 8] { + // For small sizes we add a small offset of +8, + // to ensure we also benchmark the non-SIMD tail handling. + // For large sizes, its relative impact is negligible. 
+ for &bytes in &[8usize, 128 + 8, KIBI, 128 * KIBI, 128 * MEBI] { group.throughput(Throughput::Bytes(bytes as u64)).bench_with_input( BenchmarkId::new(&name, bytes), &bytes, @@ -206,6 +227,7 @@ fn bench(c: &mut Criterion) { bench_buffer(c); bench_hash(c); bench_oklab(c); + bench_simd_lines_fwd(c); bench_simd_memchr2(c); bench_simd_memset::(c); bench_simd_memset::(c); diff --git a/src/bin/edit/documents.rs b/src/bin/edit/documents.rs index d221db8..33fc8cf 100644 --- a/src/bin/edit/documents.rs +++ b/src/bin/edit/documents.rs @@ -8,7 +8,6 @@ use std::path::{Path, PathBuf}; use edit::buffer::{RcTextBuffer, TextBuffer}; use edit::helpers::{CoordType, Point}; -use edit::simd::memrchr2; use edit::{apperr, path, sys}; use crate::state::DisplayablePathBuf; @@ -244,8 +243,12 @@ impl DocumentManager { Some(num) } + fn find_colon_rev(bytes: &[u8], offset: usize) -> Option<usize> { + (0..offset.min(bytes.len())).rev().find(|&i| bytes[i] == b':') + } + let bytes = path.as_os_str().as_encoded_bytes(); - let colend = match memrchr2(b':', b':', bytes, bytes.len()) { + let colend = match find_colon_rev(bytes, bytes.len()) { // Reject filenames that would result in an empty filename after stripping off the :line:char suffix. // For instance, a filename like ":123:456" will not be processed by this function. Some(colend) if colend > 0 => colend, @@ -260,7 +263,7 @@ impl DocumentManager { let mut len = colend; let mut goto = Point { x: 0, y: last }; - if let Some(colbeg) = memrchr2(b':', b':', bytes, colend) { + if let Some(colbeg) = find_colon_rev(bytes, colend) { // Same here: Don't allow empty filenames. 
if colbeg != 0 && let Some(first) = parse(&bytes[colbeg + 1..colend]) diff --git a/src/buffer/mod.rs b/src/buffer/mod.rs index d341db8..04c0179 100644 --- a/src/buffer/mod.rs +++ b/src/buffer/mod.rs @@ -44,7 +44,7 @@ use crate::helpers::*; use crate::oklab::oklab_blend; use crate::simd::memchr2; use crate::unicode::{self, Cursor, MeasurementConfig}; -use crate::{apperr, icu}; +use crate::{apperr, icu, simd}; /// The margin template is used for line numbers. /// The max. line number we should ever expect is probably 64-bit, @@ -341,7 +341,7 @@ impl TextBuffer { break 'outer; } - let (delta, line) = unicode::newlines_forward(chunk, 0, 0, 1); + let (delta, line) = simd::lines_fwd(chunk, 0, 0, 1); off += delta; if line == 1 { break; @@ -684,7 +684,7 @@ impl TextBuffer { } } - (offset, lines) = unicode::newlines_forward(chunk, offset, lines, lines + 1); + (offset, lines) = simd::lines_fwd(chunk, offset, lines, lines + 1); // Check if the preceding line ended in CRLF. if offset >= 2 && &chunk[offset - 2..offset] == b"\r\n" { @@ -723,7 +723,7 @@ impl TextBuffer { // If the file has more than 1000 lines, figure out how many are remaining. 
if offset < chunk.len() { - (_, lines) = unicode::newlines_forward(chunk, offset, lines, CoordType::MAX); + (_, lines) = simd::lines_fwd(chunk, offset, lines, CoordType::MAX); } let final_newline = chunk.ends_with(b"\n"); @@ -1219,7 +1219,7 @@ impl TextBuffer { break; } - let (delta, line) = unicode::newlines_forward(chunk, 0, result.logical_pos.y, y); + let (delta, line) = simd::lines_fwd(chunk, 0, result.logical_pos.y, y); result.offset += delta; result.logical_pos.y = line; } @@ -1239,8 +1239,7 @@ impl TextBuffer { break; } - let (delta, line) = - unicode::newlines_backward(chunk, chunk.len(), result.logical_pos.y, y); + let (delta, line) = simd::lines_bwd(chunk, chunk.len(), result.logical_pos.y, y); result.offset -= chunk.len() - delta; result.logical_pos.y = line; if delta > 0 { @@ -2082,7 +2081,7 @@ impl TextBuffer { selection_end.x -= remove as CoordType; } - (offset, y) = unicode::newlines_forward(&replacement, offset, y, y + 1); + (offset, y) = simd::lines_fwd(&replacement, offset, y, y + 1); } if replacement.len() == initial_len { @@ -2376,7 +2375,7 @@ impl TextBuffer { let mut offset = cursor.offset; while beg < added.len() { - let (end, line) = unicode::newlines_forward(added, beg, 0, 1); + let (end, line) = simd::lines_fwd(added, beg, 0, 1); let has_newline = line != 0; let link = &added[beg..end]; let line = unicode::strip_newline(link); diff --git a/src/simd/lines_bwd.rs b/src/simd/lines_bwd.rs new file mode 100644 index 0000000..dbe59ad --- /dev/null +++ b/src/simd/lines_bwd.rs @@ -0,0 +1,283 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::ptr; + +use crate::helpers::CoordType; + +/// Starting from the `offset` in `haystack` with a current line index of +/// `line`, this seeks backwards to the `line_stop`-nth line and returns the +/// new offset and the line index at that point. 
+/// +/// Note that this function differs from `lines_fwd` in that it +/// seeks backwards even if the `line` is already at `line_stop`. +/// This allows you to ensure (or test) whether `offset` is at a line start. +/// +/// It returns an offset *past* a newline and thus at the start of a line. +pub fn lines_bwd( + haystack: &[u8], + offset: usize, + line: CoordType, + line_stop: CoordType, +) -> (usize, CoordType) { + unsafe { + let beg = haystack.as_ptr(); + let it = beg.add(offset.min(haystack.len())); + let (it, line) = lines_bwd_raw(beg, it, line, line_stop); + (it.offset_from_unsigned(beg), line) + } +} + +unsafe fn lines_bwd_raw( + beg: *const u8, + end: *const u8, + line: CoordType, + line_stop: CoordType, +) -> (*const u8, CoordType) { + #[cfg(target_arch = "x86_64")] + return unsafe { LINES_BWD_DISPATCH(beg, end, line, line_stop) }; + + #[cfg(target_arch = "aarch64")] + return unsafe { lines_bwd_neon(beg, end, line, line_stop) }; + + #[allow(unreachable_code)] + return unsafe { lines_bwd_fallback(beg, end, line, line_stop) }; +} + +unsafe fn lines_bwd_fallback( + beg: *const u8, + mut end: *const u8, + mut line: CoordType, + line_stop: CoordType, +) -> (*const u8, CoordType) { + unsafe { + while !ptr::eq(end, beg) { + let n = end.sub(1); + if *n == b'\n' { + if line <= line_stop { + break; + } + line -= 1; + } + end = n; + } + (end, line) + } +} + +#[cfg(target_arch = "x86_64")] +static mut LINES_BWD_DISPATCH: unsafe fn( + beg: *const u8, + end: *const u8, + line: CoordType, + line_stop: CoordType, +) -> (*const u8, CoordType) = lines_bwd_dispatch; + +#[cfg(target_arch = "x86_64")] +unsafe fn lines_bwd_dispatch( + beg: *const u8, + end: *const u8, + line: CoordType, + line_stop: CoordType, +) -> (*const u8, CoordType) { + let func = if is_x86_feature_detected!("avx2") { lines_bwd_avx2 } else { lines_bwd_fallback }; + unsafe { LINES_BWD_DISPATCH = func }; + unsafe { func(beg, end, line, line_stop) } +} + +#[cfg(target_arch = "x86_64")] 
+#[target_feature(enable = "avx2")] +unsafe fn lines_bwd_avx2( + beg: *const u8, + mut end: *const u8, + mut line: CoordType, + line_stop: CoordType, +) -> (*const u8, CoordType) { + unsafe { + use std::arch::x86_64::*; + + #[inline(always)] + unsafe fn horizontal_sum_i64(v: __m256i) -> i64 { + unsafe { + let hi = _mm256_extracti128_si256::<1>(v); + let lo = _mm256_castsi256_si128(v); + let sum = _mm_add_epi64(lo, hi); + let shuf = _mm_shuffle_epi32::<0b11_10_11_10>(sum); + let sum = _mm_add_epi64(sum, shuf); + _mm_cvtsi128_si64(sum) + } + } + + let lf = _mm256_set1_epi8(b'\n' as i8); + let line_stop = line_stop.min(line); + let mut remaining = end.offset_from_unsigned(beg); + + while remaining >= 128 { + let chunk_start = end.sub(128); + + let v1 = _mm256_loadu_si256(chunk_start.add(0) as *const _); + let v2 = _mm256_loadu_si256(chunk_start.add(32) as *const _); + let v3 = _mm256_loadu_si256(chunk_start.add(64) as *const _); + let v4 = _mm256_loadu_si256(chunk_start.add(96) as *const _); + + let mut sum = _mm256_setzero_si256(); + sum = _mm256_sub_epi8(sum, _mm256_cmpeq_epi8(v1, lf)); + sum = _mm256_sub_epi8(sum, _mm256_cmpeq_epi8(v2, lf)); + sum = _mm256_sub_epi8(sum, _mm256_cmpeq_epi8(v3, lf)); + sum = _mm256_sub_epi8(sum, _mm256_cmpeq_epi8(v4, lf)); + + let sum = _mm256_sad_epu8(sum, _mm256_setzero_si256()); + let sum = horizontal_sum_i64(sum); + + let line_next = line - sum as CoordType; + if line_next <= line_stop { + break; + } + + end = chunk_start; + remaining -= 128; + line = line_next; + } + + while remaining >= 32 { + let chunk_start = end.sub(32); + let v = _mm256_loadu_si256(chunk_start as *const _); + let c = _mm256_cmpeq_epi8(v, lf); + + let ones = _mm256_and_si256(c, _mm256_set1_epi8(0x01)); + let sum = _mm256_sad_epu8(ones, _mm256_setzero_si256()); + let sum = horizontal_sum_i64(sum); + + let line_next = line - sum as CoordType; + if line_next <= line_stop { + break; + } + + end = chunk_start; + remaining -= 32; + line = line_next; + } + + 
lines_bwd_fallback(beg, end, line, line_stop) + } +} + +#[cfg(target_arch = "aarch64")] +unsafe fn lines_bwd_neon( + beg: *const u8, + mut end: *const u8, + mut line: CoordType, + line_stop: CoordType, +) -> (*const u8, CoordType) { + unsafe { + use std::arch::aarch64::*; + + let lf = vdupq_n_u8(b'\n'); + let line_stop = line_stop.min(line); + let mut remaining = end.offset_from_unsigned(beg); + + while remaining >= 64 { + let chunk_start = end.sub(64); + + let v1 = vld1q_u8(chunk_start.add(0)); + let v2 = vld1q_u8(chunk_start.add(16)); + let v3 = vld1q_u8(chunk_start.add(32)); + let v4 = vld1q_u8(chunk_start.add(48)); + + let mut sum = vdupq_n_u8(0); + sum = vsubq_u8(sum, vceqq_u8(v1, lf)); + sum = vsubq_u8(sum, vceqq_u8(v2, lf)); + sum = vsubq_u8(sum, vceqq_u8(v3, lf)); + sum = vsubq_u8(sum, vceqq_u8(v4, lf)); + + let sum = vaddvq_u8(sum); + + let line_next = line - sum as CoordType; + if line_next <= line_stop { + break; + } + + end = chunk_start; + remaining -= 64; + line = line_next; + } + + while remaining >= 16 { + let chunk_start = end.sub(16); + let v = vld1q_u8(chunk_start); + let c = vceqq_u8(v, lf); + let c = vandq_u8(c, vdupq_n_u8(0x01)); + let sum = vaddvq_u8(c); + + let line_next = line - sum as CoordType; + if line_next <= line_stop { + break; + } + + end = chunk_start; + remaining -= 16; + line = line_next; + } + + lines_bwd_fallback(beg, end, line, line_stop) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::helpers::CoordType; + use crate::simd::test::*; + + #[test] + fn pseudo_fuzz() { + let text = generate_random_text(1024); + let lines = count_lines(&text); + let mut offset_rng = make_rng(); + let mut line_rng = make_rng(); + let mut line_distance_rng = make_rng(); + + for _ in 0..1000 { + let offset = offset_rng() % (text.len() + 1); + let line_stop = line_distance_rng() % (lines + 1); + let line = line_stop + line_rng() % 100; + + let line = line as CoordType; + let line_stop = line_stop as CoordType; + + let expected = 
reference_lines_bwd(text.as_bytes(), offset, line, line_stop); + let actual = lines_bwd(text.as_bytes(), offset, line, line_stop); + + assert_eq!(expected, actual); + } + } + + fn reference_lines_bwd( + haystack: &[u8], + mut offset: usize, + mut line: CoordType, + line_stop: CoordType, + ) -> (usize, CoordType) { + if line >= line_stop { + while offset > 0 { + let c = haystack[offset - 1]; + if c == b'\n' { + if line == line_stop { + break; + } + line -= 1; + } + offset -= 1; + } + } + (offset, line) + } + #[test] + fn seeks_to_start() { + for i in 6..=11 { + let (off, line) = lines_bwd(b"Hello\nWorld\n", i, 123, 456); + assert_eq!(off, 6); // After "Hello\n" + assert_eq!(line, 123); // Still on the same line + } + } +} diff --git a/src/simd/lines_fwd.rs b/src/simd/lines_fwd.rs new file mode 100644 index 0000000..e2d11f1 --- /dev/null +++ b/src/simd/lines_fwd.rs @@ -0,0 +1,281 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use std::ptr; + +use crate::helpers::CoordType; + +/// Starting from the `offset` in `haystack` with a current line index of +/// `line`, this seeks to the `line_stop`-nth line and returns the +/// new offset and the line index at that point. +/// +/// It returns an offset *past* the newline. +/// If `line` is already at or past `line_stop`, it returns immediately. 
+pub fn lines_fwd( + haystack: &[u8], + offset: usize, + line: CoordType, + line_stop: CoordType, +) -> (usize, CoordType) { + unsafe { + let beg = haystack.as_ptr(); + let end = beg.add(haystack.len()); + let it = beg.add(offset.min(haystack.len())); + let (it, line) = lines_fwd_raw(it, end, line, line_stop); + (it.offset_from_unsigned(beg), line) + } +} + +unsafe fn lines_fwd_raw( + beg: *const u8, + end: *const u8, + line: CoordType, + line_stop: CoordType, +) -> (*const u8, CoordType) { + #[cfg(target_arch = "x86_64")] + return unsafe { LINES_FWD_DISPATCH(beg, end, line, line_stop) }; + + #[cfg(target_arch = "aarch64")] + return unsafe { lines_fwd_neon(beg, end, line, line_stop) }; + + #[allow(unreachable_code)] + return unsafe { lines_fwd_fallback(beg, end, line, line_stop) }; +} + +unsafe fn lines_fwd_fallback( + mut beg: *const u8, + end: *const u8, + mut line: CoordType, + line_stop: CoordType, +) -> (*const u8, CoordType) { + unsafe { + if line < line_stop { + while !ptr::eq(beg, end) { + let c = *beg; + beg = beg.add(1); + if c == b'\n' { + line += 1; + if line == line_stop { + break; + } + } + } + } + (beg, line) + } +} + +#[cfg(target_arch = "x86_64")] +static mut LINES_FWD_DISPATCH: unsafe fn( + beg: *const u8, + end: *const u8, + line: CoordType, + line_stop: CoordType, +) -> (*const u8, CoordType) = lines_fwd_dispatch; + +#[cfg(target_arch = "x86_64")] +unsafe fn lines_fwd_dispatch( + beg: *const u8, + end: *const u8, + line: CoordType, + line_stop: CoordType, +) -> (*const u8, CoordType) { + let func = if is_x86_feature_detected!("avx2") { lines_fwd_avx2 } else { lines_fwd_fallback }; + unsafe { LINES_FWD_DISPATCH = func }; + unsafe { func(beg, end, line, line_stop) } +} + +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +unsafe fn lines_fwd_avx2( + mut beg: *const u8, + end: *const u8, + mut line: CoordType, + line_stop: CoordType, +) -> (*const u8, CoordType) { + unsafe { + use std::arch::x86_64::*; + + #[inline(always)] + 
unsafe fn horizontal_sum_i64(v: __m256i) -> i64 { + unsafe { + let hi = _mm256_extracti128_si256::<1>(v); + let lo = _mm256_castsi256_si128(v); + let sum = _mm_add_epi64(lo, hi); + let shuf = _mm_shuffle_epi32::<0b11_10_11_10>(sum); + let sum = _mm_add_epi64(sum, shuf); + _mm_cvtsi128_si64(sum) + } + } + + let lf = _mm256_set1_epi8(b'\n' as i8); + let mut remaining = end.offset_from_unsigned(beg); + + if line < line_stop { + // Unrolling the loop by 4x speeds things up by >3x. + // It allows us to accumulate matches before doing a single `vpsadbw`. + while remaining >= 128 { + let v1 = _mm256_loadu_si256(beg.add(0) as *const _); + let v2 = _mm256_loadu_si256(beg.add(32) as *const _); + let v3 = _mm256_loadu_si256(beg.add(64) as *const _); + let v4 = _mm256_loadu_si256(beg.add(96) as *const _); + + // `vpcmpeqb` leaves each comparison result byte as 0 or -1 (0xff). + // This allows us to accumulate the comparisons by subtracting them. + let mut sum = _mm256_setzero_si256(); + sum = _mm256_sub_epi8(sum, _mm256_cmpeq_epi8(v1, lf)); + sum = _mm256_sub_epi8(sum, _mm256_cmpeq_epi8(v2, lf)); + sum = _mm256_sub_epi8(sum, _mm256_cmpeq_epi8(v3, lf)); + sum = _mm256_sub_epi8(sum, _mm256_cmpeq_epi8(v4, lf)); + + // Calculate the total number of matches in this chunk. + let sum = _mm256_sad_epu8(sum, _mm256_setzero_si256()); + let sum = horizontal_sum_i64(sum); + + let line_next = line + sum as CoordType; + if line_next >= line_stop { + break; + } + + beg = beg.add(128); + remaining -= 128; + line = line_next; + } + + while remaining >= 32 { + let v = _mm256_loadu_si256(beg as *const _); + let c = _mm256_cmpeq_epi8(v, lf); + + // If you ask an LLM, the best way to do this is + // to do a `vpmovmskb` followed by `popcnt`. + // On contemporary hardware that's a bad idea though. 
+ let ones = _mm256_and_si256(c, _mm256_set1_epi8(0x01)); + let sum = _mm256_sad_epu8(ones, _mm256_setzero_si256()); + let sum = horizontal_sum_i64(sum); + + let line_next = line + sum as CoordType; + if line_next >= line_stop { + break; + } + + beg = beg.add(32); + remaining -= 32; + line = line_next; + } + } + + lines_fwd_fallback(beg, end, line, line_stop) + } +} + +#[cfg(target_arch = "aarch64")] +unsafe fn lines_fwd_neon( + mut beg: *const u8, + end: *const u8, + mut line: CoordType, + line_stop: CoordType, +) -> (*const u8, CoordType) { + unsafe { + use std::arch::aarch64::*; + + let lf = vdupq_n_u8(b'\n'); + let mut remaining = end.offset_from_unsigned(beg); + + if line < line_stop { + while remaining >= 64 { + let v1 = vld1q_u8(beg.add(0)); + let v2 = vld1q_u8(beg.add(16)); + let v3 = vld1q_u8(beg.add(32)); + let v4 = vld1q_u8(beg.add(48)); + + // `vceqq_u8` leaves each comparison result byte as 0 or -1 (0xff). + // This allows us to accumulate the comparisons by subtracting them. 
+ let mut sum = vdupq_n_u8(0); + sum = vsubq_u8(sum, vceqq_u8(v1, lf)); + sum = vsubq_u8(sum, vceqq_u8(v2, lf)); + sum = vsubq_u8(sum, vceqq_u8(v3, lf)); + sum = vsubq_u8(sum, vceqq_u8(v4, lf)); + + let sum = vaddvq_u8(sum); + + let line_next = line + sum as CoordType; + if line_next >= line_stop { + break; + } + + beg = beg.add(64); + remaining -= 64; + line = line_next; + } + + while remaining >= 16 { + let v = vld1q_u8(beg); + let c = vceqq_u8(v, lf); + let c = vandq_u8(c, vdupq_n_u8(0x01)); + let sum = vaddvq_u8(c); + + let line_next = line + sum as CoordType; + if line_next >= line_stop { + break; + } + + beg = beg.add(16); + remaining -= 16; + line = line_next; + } + } + + lines_fwd_fallback(beg, end, line, line_stop) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::helpers::CoordType; + use crate::simd::test::*; + + #[test] + fn pseudo_fuzz() { + let text = generate_random_text(1024); + let lines = count_lines(&text); + let mut offset_rng = make_rng(); + let mut line_rng = make_rng(); + let mut line_distance_rng = make_rng(); + + for _ in 0..1000 { + let offset = offset_rng() % (text.len() + 1); + let line = line_rng() % 100; + let line_stop = line + line_distance_rng() % (lines + 1); + + let line = line as CoordType; + let line_stop = line_stop as CoordType; + + let expected = reference_lines_fwd(text.as_bytes(), offset, line, line_stop); + let actual = lines_fwd(text.as_bytes(), offset, line, line_stop); + + assert_eq!(expected, actual); + } + } + + fn reference_lines_fwd( + haystack: &[u8], + mut offset: usize, + mut line: CoordType, + line_stop: CoordType, + ) -> (usize, CoordType) { + if line < line_stop { + while offset < haystack.len() { + let c = haystack[offset]; + offset += 1; + if c == b'\n' { + line += 1; + if line == line_stop { + break; + } + } + } + } + (offset, line) + } +} diff --git a/src/simd/memrchr2.rs b/src/simd/memrchr2.rs deleted file mode 100644 index 97c3c11..0000000 --- a/src/simd/memrchr2.rs +++ /dev/null @@ -1,194 
+0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -//! `memchr`, but with two needles. - -use std::ptr; - -/// `memchr`, but with two needles. -/// -/// If no needle is found, 0 is returned. -/// Unlike `memchr2` (or `memrchr`), an offset PAST the hit is returned. -/// This is because this function is primarily used for -/// `ucd::newlines_backward`, which needs exactly that. -pub fn memrchr2(needle1: u8, needle2: u8, haystack: &[u8], offset: usize) -> Option { - unsafe { - let beg = haystack.as_ptr(); - let it = beg.add(offset.min(haystack.len())); - let it = memrchr2_raw(needle1, needle2, beg, it); - if it.is_null() { None } else { Some(it.offset_from_unsigned(beg)) } - } -} - -unsafe fn memrchr2_raw(needle1: u8, needle2: u8, beg: *const u8, end: *const u8) -> *const u8 { - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - return unsafe { MEMRCHR2_DISPATCH(needle1, needle2, beg, end) }; - - #[cfg(target_arch = "aarch64")] - return unsafe { memrchr2_neon(needle1, needle2, beg, end) }; - - #[allow(unreachable_code)] - return unsafe { memrchr2_fallback(needle1, needle2, beg, end) }; -} - -unsafe fn memrchr2_fallback( - needle1: u8, - needle2: u8, - beg: *const u8, - mut end: *const u8, -) -> *const u8 { - unsafe { - while !ptr::eq(end, beg) { - end = end.sub(1); - let ch = *end; - if ch == needle1 || needle2 == ch { - return end; - } - } - ptr::null() - } -} - -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -static mut MEMRCHR2_DISPATCH: unsafe fn( - needle1: u8, - needle2: u8, - beg: *const u8, - end: *const u8, -) -> *const u8 = memrchr2_dispatch; - -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -unsafe fn memrchr2_dispatch(needle1: u8, needle2: u8, beg: *const u8, end: *const u8) -> *const u8 { - let func = if is_x86_feature_detected!("avx2") { memrchr2_avx2 } else { memrchr2_fallback }; - unsafe { MEMRCHR2_DISPATCH = func }; - unsafe { func(needle1, needle2, beg, end) } -} - 
-#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -#[target_feature(enable = "avx2")] -unsafe fn memrchr2_avx2(needle1: u8, needle2: u8, beg: *const u8, mut end: *const u8) -> *const u8 { - unsafe { - #[cfg(target_arch = "x86")] - use std::arch::x86::*; - #[cfg(target_arch = "x86_64")] - use std::arch::x86_64::*; - - if end.offset_from_unsigned(beg) >= 32 { - let n1 = _mm256_set1_epi8(needle1 as i8); - let n2 = _mm256_set1_epi8(needle2 as i8); - - loop { - end = end.sub(32); - - let v = _mm256_loadu_si256(end as *const _); - let a = _mm256_cmpeq_epi8(v, n1); - let b = _mm256_cmpeq_epi8(v, n2); - let c = _mm256_or_si256(a, b); - let m = _mm256_movemask_epi8(c) as u32; - - if m != 0 { - return end.add(31 - m.leading_zeros() as usize); - } - - if end.offset_from_unsigned(beg) < 32 { - break; - } - } - } - - memrchr2_fallback(needle1, needle2, beg, end) - } -} - -#[cfg(target_arch = "aarch64")] -unsafe fn memrchr2_neon(needle1: u8, needle2: u8, beg: *const u8, mut end: *const u8) -> *const u8 { - unsafe { - use std::arch::aarch64::*; - - if end.offset_from_unsigned(beg) >= 16 { - let n1 = vdupq_n_u8(needle1); - let n2 = vdupq_n_u8(needle2); - - loop { - end = end.sub(16); - - let v = vld1q_u8(end as *const _); - let a = vceqq_u8(v, n1); - let b = vceqq_u8(v, n2); - let c = vorrq_u8(a, b); - - // https://community.arm.com/arm-community-blogs/b/servers-and-cloud-computing-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon - let m = vreinterpretq_u16_u8(c); - let m = vshrn_n_u16(m, 4); - let m = vreinterpret_u64_u8(m); - let m = vget_lane_u64(m, 0); - - if m != 0 { - return end.add(15 - (m.leading_zeros() as usize >> 2)); - } - - if end.offset_from_unsigned(beg) < 16 { - break; - } - } - } - - memrchr2_fallback(needle1, needle2, beg, end) - } -} - -#[cfg(test)] -mod tests { - use std::slice; - - use super::*; - use crate::sys; - - #[test] - fn test_empty() { - assert_eq!(memrchr2(b'a', b'b', b"", 0), None); - } - - #[test] - fn test_basic() { - let 
haystack = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; - let haystack = &haystack[..43]; - - assert_eq!(memrchr2(b'Q', b'P', haystack, 43), Some(42)); - assert_eq!(memrchr2(b'p', b'o', haystack, 43), Some(15)); - assert_eq!(memrchr2(b'a', b'b', haystack, 43), Some(1)); - assert_eq!(memrchr2(b'0', b'9', haystack, 43), None); - } - - // Test that it doesn't match before/after the start offset respectively. - #[test] - fn test_with_offset() { - let haystack = b"abcdefghabcdefghabcdefghabcdefghabcdefgh"; - - assert_eq!(memrchr2(b'h', b'g', haystack, 40), Some(39)); - assert_eq!(memrchr2(b'h', b'g', haystack, 39), Some(38)); - assert_eq!(memrchr2(b'a', b'b', haystack, 9), Some(8)); - assert_eq!(memrchr2(b'a', b'b', haystack, 1), Some(0)); - assert_eq!(memrchr2(b'a', b'b', haystack, 0), None); - } - - // Test memory access safety at page boundaries. - // The test is a success if it doesn't segfault. - #[test] - fn test_page_boundary() { - let page = unsafe { - const PAGE_SIZE: usize = 64 * 1024; // 64 KiB to cover many architectures. - - // 3 pages: uncommitted, committed, uncommitted - let ptr = sys::virtual_reserve(PAGE_SIZE * 3).unwrap(); - sys::virtual_commit(ptr.add(PAGE_SIZE), PAGE_SIZE).unwrap(); - slice::from_raw_parts_mut(ptr.add(PAGE_SIZE).as_ptr(), PAGE_SIZE) - }; - - page.fill(b'a'); - - // Same as above, but for memrchr2 (hence reversed). - assert_eq!(memrchr2(b'\0', b'\0', &page[page.len() - 10..], 10), None); - assert_eq!(memrchr2(b'\0', b'\0', &page[..40], 40), None); - } -} diff --git a/src/simd/mod.rs b/src/simd/mod.rs index 542d985..7f60ed4 100644 --- a/src/simd/mod.rs +++ b/src/simd/mod.rs @@ -3,10 +3,41 @@ //! Provides various high-throughput utilities. 
+pub mod lines_bwd; +pub mod lines_fwd; mod memchr2; -mod memrchr2; mod memset; +pub use lines_bwd::*; +pub use lines_fwd::*; pub use memchr2::*; -pub use memrchr2::*; pub use memset::*; + +#[cfg(test)] +mod test { + // Knuth's MMIX LCG + pub fn make_rng() -> impl FnMut() -> usize { + let mut state = 1442695040888963407u64; + move || { + state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407); + state as usize + } + } + + pub fn generate_random_text(len: usize) -> String { + const ALPHABET: &[u8; 20] = b"0123456789abcdef\n\n\n\n"; + + let mut rng = make_rng(); + let mut res = String::new(); + + for _ in 0..len { + res.push(ALPHABET[rng() % ALPHABET.len()] as char); + } + + res + } + + pub fn count_lines(text: &str) -> usize { + text.lines().count() + } +} diff --git a/src/tui.rs b/src/tui.rs index 87d2b4a..dd181a5 100644 --- a/src/tui.rs +++ b/src/tui.rs @@ -157,7 +157,7 @@ use crate::framebuffer::{Attributes, Framebuffer, INDEXED_COLORS_COUNT, IndexedC use crate::hash::*; use crate::helpers::*; use crate::input::{InputKeyMod, kbmod, vk}; -use crate::{apperr, arena_format, input, unicode}; +use crate::{apperr, arena_format, input, simd, unicode}; const ROOT_ID: u64 = 0x14057B7EF767814F; // Knuth's MMIX constant const SHIFT_TAB: InputKey = vk::TAB.with_modifiers(kbmod::SHIFT); @@ -2690,7 +2690,7 @@ impl<'a> Context<'a, '_> { } if single_line && !write.is_empty() { - let (end, _) = unicode::newlines_forward(write, 0, 0, 1); + let (end, _) = simd::lines_fwd(write, 0, 0, 1); write = unicode::strip_newline(&write[..end]); } if !write.is_empty() { diff --git a/src/unicode/measurement.rs b/src/unicode/measurement.rs index 4aac8ac..38e22ad 100644 --- a/src/unicode/measurement.rs +++ b/src/unicode/measurement.rs @@ -7,7 +7,6 @@ use super::Utf8Chars; use super::tables::*; use crate::document::ReadableDocument; use crate::helpers::{CoordType, Point}; -use crate::simd::{memchr2, memrchr2}; // On one hand it's disgusting that I wrote this as a global 
variable, but on the // other hand, this isn't a public library API, and it makes the code a lot cleaner, @@ -478,104 +477,6 @@ impl<'doc> MeasurementConfig<'doc> { } } -/// Seeks forward to the given line start. -/// -/// If given a piece of `text`, and assuming you're currently at `offset` which -/// is on the logical line `line`, this will seek forward until the logical line -/// `line_stop` is reached. For instance, if `line` is 0 and `line_stop` is 2, -/// it'll seek forward past 2 line feeds. -/// -/// This function always stops exactly past a line feed -/// and thus returns a position at the start of a line. -/// -/// # Warning -/// -/// If the end of `text` is hit before reaching `line_stop`, the function -/// will return an offset of `text.len()`, not at the start of a line. -/// -/// # Parameters -/// -/// * `text`: The text to search in. -/// * `offset`: The offset to start searching from. -/// * `line`: The current line. -/// * `line_stop`: The line to stop at. -/// -/// # Returns -/// -/// A tuple consisting of: -/// * The new offset. -/// * The line number that was reached. -pub fn newlines_forward( - text: &[u8], - mut offset: usize, - mut line: CoordType, - line_stop: CoordType, -) -> (usize, CoordType) { - // Leaving the cursor at the beginning of the current line when the limit - // is 0 makes this function behave identical to ucd_newlines_backward. - if line >= line_stop { - return newlines_backward(text, offset, line, line_stop); - } - - let len = text.len(); - offset = offset.min(len); - - loop { - // TODO: This code could be optimized by replacing memchr with manual line counting. - // - // If `line_stop` is very far away, we could accumulate newline counts horizontally - // in a AVX2 register (= 32 u8 slots). Then, every 256 bytes we compute the horizontal - // sum via `_mm256_sad_epu8` yielding us the newline count in the last block. - // - // We could also just use `_mm256_sad_epu8` on each fetch as-is. 
- offset = memchr2(b'\n', b'\n', text, offset); - if offset >= len { - break; - } - - offset += 1; - line += 1; - if line >= line_stop { - break; - } - } - - (offset, line) -} - -/// Seeks backward to the given line start. -/// -/// See [`newlines_forward`] for details. -/// This function does almost the same thing, but in reverse. -/// -/// # Warning -/// -/// In addition to the notes in [`newlines_forward`]: -/// -/// No matter what parameters are given, [`newlines_backward`] only returns an -/// offset at the start of a line. Put differently, even if `line == line_stop`, -/// it'll seek backward to the line start. -pub fn newlines_backward( - text: &[u8], - mut offset: usize, - mut line: CoordType, - line_stop: CoordType, -) -> (usize, CoordType) { - offset = offset.min(text.len()); - - loop { - offset = match memrchr2(b'\n', b'\n', text, offset) { - Some(i) => i, - None => return (0, line), - }; - if line <= line_stop { - // +1: Past the newline, at the start of the current line. - return (offset + 1, line); - } - line -= 1; - } -} - /// Returns an offset past a newline. /// /// If `offset` is right in front of a newline, @@ -1152,23 +1053,6 @@ mod test { ); } - #[test] - fn test_newlines_and_strip() { - // Offset line 0: 0 - // Offset line 1: 6 - // Offset line 2: 13 - // Offset line 3: 18 - let text = "line1\nline2\r\nline3".as_bytes(); - - assert_eq!(newlines_forward(text, 0, 0, 2), (13, 2)); - assert_eq!(newlines_forward(text, 0, 0, 0), (0, 0)); - assert_eq!(newlines_forward(text, 100, 2, 100), (18, 2)); - - assert_eq!(newlines_backward(text, 18, 2, 1), (6, 1)); - assert_eq!(newlines_backward(text, 18, 2, 0), (0, 0)); - assert_eq!(newlines_backward(text, 100, 2, 1), (6, 1)); - } - #[test] fn test_strip_newline() { assert_eq!(strip_newline(b"hello\n"), b"hello");