add tailswap for 31 or less elements

This commit is contained in:
Brendan Hansknecht 2024-07-23 15:51:53 -07:00
parent 33e6dabeba
commit a7823c2164
No known key found for this signature in database
GPG key ID: 0EA784685083E75B

View file

@ -69,7 +69,50 @@ fn quadsort_direct(
}
// ================ Small Arrays ==============================================
// Below are functions for sorting under 31 element arrays.
// Below are functions for sorting under 32 element arrays.
/// Uses swap space to sort the tail of an array.
/// The array should be under 32 elements in length.
fn tail_swap(
array: [*]u8,
len: usize,
swap: [*]u8,
cmp_data: Opaque,
cmp: CompareFn,
element_width: usize,
copy: CopyFn,
) void {
std.debug.assert(len < 32);
if (len < 8) {
tiny_sort(array, len, swap, cmp_data, cmp, element_width, copy);
return;
}
const half1 = len / 2;
const quad1 = half1 / 2;
const quad2 = half1 - quad1;
const half2 = len - half1;
const quad3 = half2 / 2;
const quad4 = half2 - quad3;
var arr_ptr = array;
tail_swap(arr_ptr, quad1, swap, cmp_data, cmp, element_width, copy);
arr_ptr += quad1 * element_width;
tail_swap(arr_ptr, quad2, swap, cmp_data, cmp, element_width, copy);
arr_ptr += quad2 * element_width;
tail_swap(arr_ptr, quad3, swap, cmp_data, cmp, element_width, copy);
arr_ptr += quad3 * element_width;
tail_swap(arr_ptr, quad4, swap, cmp_data, cmp, element_width, copy);
if (compare(cmp, cmp_data, array + (quad1 - 1) * element_width, array + quad1 * element_width) != GT and compare(cmp, cmp_data, array + (half1 - 1) * element_width, array + half1 * element_width) != GT and compare(cmp, cmp_data, arr_ptr - 1 * element_width, arr_ptr) != GT) {
return;
}
parity_merge(swap, array, quad1, quad2, cmp_data, cmp, element_width, copy);
parity_merge(swap + half1 * element_width, array + half1 * element_width, quad3, quad4, cmp_data, cmp, element_width, copy);
parity_merge(array, swap, half1, half2, cmp_data, cmp, element_width, copy);
}
/// Merges two neighboring sorted arrays into dest.
/// Left must be equal to or 1 smaller than right.
@ -106,6 +149,27 @@ fn parity_merge(
tail_branchless_merge(&dest_tail, &left_tail, &right_tail, cmp_data, cmp, element_width, copy);
}
test "tail_swap" {
var swap: [31]i64 = undefined;
var swap_ptr = @as([*]u8, @ptrCast(&swap[0]));
var arr: [31]i64 = undefined;
var expected: [31]i64 = undefined;
for (0..31) |i| {
arr[i] = @intCast(i + 1);
expected[i] = @intCast(i + 1);
}
var arr_ptr = @as([*]u8, @ptrCast(&arr[0]));
for (0..10) |seed| {
var rng = std.rand.DefaultPrng.init(seed);
rng.random().shuffle(i64, arr[0..]);
tail_swap(arr_ptr, 31, swap_ptr, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy);
try testing.expectEqual(arr, expected);
}
}
test "parity_merge" {
{
var dest: [8]i64 = undefined;