add partial backwards merge

This commit is contained in:
Brendan Hansknecht 2024-07-23 22:10:47 -07:00
parent ea0063b992
commit eb8c91775f
No known key found for this signature in database
GPG key ID: 0EA784685083E75B

View file

@ -76,6 +76,215 @@ fn quadsort_direct(
// ================ Unbalanced Merges =========================================
/// Merges a full left block with a smaller than block size right chunk.
/// The merge goes from tail to head.
fn partial_backwards_merge(
array: [*]u8,
len: usize,
swap: [*]u8,
swap_len: usize,
block_len: usize,
cmp_data: Opaque,
cmp: CompareFn,
element_width: usize,
copy: CopyFn,
) void {
std.debug.assert(swap_len >= block_len);
if (len == block_len) {
// Just a single block, already done.
return;
}
var left_tail = array + (block_len - 1) * element_width;
var dest_tail = array + (len - 1) * element_width;
if (compare(cmp, cmp_data, left_tail, left_tail + element_width) != GT) {
// Lucky case, blocks happen to be sorted.
return;
}
const right_len = len - block_len;
if (len <= swap_len and right_len >= 64) {
// Large remaining merge and we have enough space to just do it in swap.
cross_merge(swap, array, block_len, right_len, cmp_data, cmp, element_width, copy);
@memcpy(array[0..(element_width * len)], swap[0..(element_width * len)]);
return;
}
@memcpy(swap[0..(element_width * right_len)], (array + block_len * element_width)[0..(element_width * right_len)]);
var right_tail = swap + (right_len - 1) * element_width;
// For backards, we first try to do really large chunks, of 16 elements.
outer: while (@intFromPtr(left_tail) > @intFromPtr(array + 16 * element_width) and @intFromPtr(right_tail) > @intFromPtr(swap + 16 * element_width)) {
while (compare(cmp, cmp_data, left_tail, right_tail - 15 * element_width) != GT) {
inline for (0..16) |_| {
copy(dest_tail, right_tail);
dest_tail -= element_width;
right_tail -= element_width;
}
if (@intFromPtr(right_tail) <= @intFromPtr(swap + 16 * element_width))
break :outer;
}
while (compare(cmp, cmp_data, left_tail - 15 * element_width, right_tail) == GT) {
inline for (0..16) |_| {
copy(dest_tail, left_tail);
dest_tail -= element_width;
left_tail -= element_width;
}
if (@intFromPtr(left_tail) <= @intFromPtr(array + 16 * element_width))
break :outer;
}
// Attempt to deal with the rest of the chunk in groups of 2.
var loops: usize = 8;
while (true) {
if (compare(cmp, cmp_data, left_tail, right_tail - element_width) != GT) {
inline for (0..2) |_| {
copy(dest_tail, right_tail);
dest_tail -= element_width;
right_tail -= element_width;
}
} else if (compare(cmp, cmp_data, left_tail - element_width, right_tail) == GT) {
inline for (0..2) |_| {
copy(dest_tail, left_tail);
dest_tail -= element_width;
left_tail -= element_width;
}
} else {
// Couldn't move two elements, do a cross swap and continue.
const lte = compare(cmp, cmp_data, left_tail, right_tail) != GT;
var x = if (lte) element_width else 0;
var not_x = if (!lte) element_width else 0;
dest_tail -= element_width;
copy(dest_tail + x, right_tail);
right_tail -= element_width;
copy(dest_tail + not_x, left_tail);
left_tail -= element_width;
dest_tail -= element_width;
tail_branchless_merge(&dest_tail, &left_tail, &right_tail, cmp_data, cmp, element_width, copy);
}
loops -= 1;
if (loops == 0)
break;
}
}
// For rest of tail, attempt to merge 2 elements a time from tail to head.
while (@intFromPtr(right_tail) > @intFromPtr(swap) + element_width and @intFromPtr(left_tail) > @intFromPtr(array) + element_width) {
// Note: I am not sure how to get the same generation as the original C.
// This implementation has an extra function call here.
// The C use `goto` to implement the two tail recursive functions below inline.
const break_loop = partial_forward_merge_right_tail_2(&dest_tail, &array, &left_tail, &swap, &right_tail, cmp_data, cmp, element_width, copy);
if (break_loop)
break;
// Couldn't move two elements, do a cross swap and continue.
const lte = compare(cmp, cmp_data, left_tail, right_tail) != GT;
var x = if (lte) element_width else 0;
var not_x = if (!lte) element_width else 0;
dest_tail -= element_width;
copy(dest_tail + x, right_tail);
right_tail -= element_width;
copy(dest_tail + not_x, left_tail);
left_tail -= element_width;
dest_tail -= element_width;
tail_branchless_merge(&dest_tail, &left_tail, &right_tail, cmp_data, cmp, element_width, copy);
}
// Deal with tail.
while (@intFromPtr(right_tail) >= @intFromPtr(swap) and @intFromPtr(left_tail) >= @intFromPtr(array)) {
tail_branchless_merge(&dest_tail, &left_tail, &right_tail, cmp_data, cmp, element_width, copy);
}
while (@intFromPtr(right_tail) >= @intFromPtr(swap)) {
copy(dest_tail, right_tail);
dest_tail -= element_width;
right_tail -= element_width;
}
}
// The following two functions are exactly the same but with the if blocks swapped.
// They hot loop on one side until it fails, then switch to the other list.
fn partial_forward_merge_right_tail_2(
dest: *[*]u8,
left_head: *const [*]u8,
left_tail: *[*]u8,
right_head: *const [*]u8,
right_tail: *[*]u8,
cmp_data: Opaque,
cmp: CompareFn,
element_width: usize,
copy: CopyFn,
) bool {
if (compare(cmp, cmp_data, left_tail.*, right_tail.* - element_width) != GT) {
inline for (0..2) |_| {
copy(dest.*, right_tail.*);
dest.* -= element_width;
right_tail.* -= element_width;
}
if (@intFromPtr(right_tail.*) > @intFromPtr(right_head.*) + element_width) {
return @call(.always_tail, partial_forward_merge_right_tail_2, .{ dest, left_head, left_tail, right_head, right_tail, cmp_data, cmp, element_width, copy });
}
return true;
}
if (compare(cmp, cmp_data, left_tail.* - element_width, right_tail.*) == GT) {
inline for (0..2) |_| {
copy(dest.*, left_tail.*);
dest.* -= element_width;
left_tail.* -= element_width;
}
if (@intFromPtr(left_tail.*) > @intFromPtr(left_head.*) + element_width) {
return @call(.always_tail, partial_forward_merge_left_tail_2, .{ dest, left_head, left_tail, right_head, right_tail, cmp_data, cmp, element_width, copy });
}
return true;
}
return false;
}
fn partial_forward_merge_left_tail_2(
dest: *[*]u8,
left_head: *const [*]u8,
left_tail: *[*]u8,
right_head: *const [*]u8,
right_tail: *[*]u8,
cmp_data: Opaque,
cmp: CompareFn,
element_width: usize,
copy: CopyFn,
) bool {
if (compare(cmp, cmp_data, left_tail.* - element_width, right_tail.*) == GT) {
inline for (0..2) |_| {
copy(dest.*, left_tail.*);
dest.* -= element_width;
left_tail.* -= element_width;
}
if (@intFromPtr(left_tail.*) > @intFromPtr(left_head.*) + element_width) {
return @call(.always_tail, partial_forward_merge_left_tail_2, .{ dest, left_head, left_tail, right_head, right_tail, cmp_data, cmp, element_width, copy });
}
return true;
}
if (compare(cmp, cmp_data, left_tail.*, right_tail.* - element_width) != GT) {
inline for (0..2) |_| {
copy(dest.*, right_tail.*);
dest.* -= element_width;
right_tail.* -= element_width;
}
if (@intFromPtr(right_tail.*) > @intFromPtr(right_head.*) + element_width) {
return @call(.always_tail, partial_forward_merge_right_tail_2, .{ dest, left_head, left_tail, right_head, right_tail, cmp_data, cmp, element_width, copy });
}
return true;
}
return false;
}
/// Merges a full left block with a smaller than block size right chunk.
/// The merge goes from head to tail.
fn partial_forward_merge(
array: [*]u8,
len: usize,
@ -98,7 +307,7 @@ fn partial_forward_merge(
var right_tail = array + (len - 1) * element_width;
if (compare(cmp, cmp_data, right_head - element_width, right_head) != GT) {
// Luck case, blocks happen to be sorted.
// Lucky case, blocks happen to be sorted.
return;
}
@ -147,9 +356,9 @@ fn partial_forward_merge(
fn partial_forward_merge_right_head_2(
dest: *[*]u8,
left_head: *[*]u8,
left_tail: *[*]u8,
left_tail: *const [*]u8,
right_head: *[*]u8,
right_tail: *[*]u8,
right_tail: *const [*]u8,
cmp_data: Opaque,
cmp: CompareFn,
element_width: usize,
@ -183,9 +392,9 @@ fn partial_forward_merge_right_head_2(
fn partial_forward_merge_left_head_2(
dest: *[*]u8,
left_head: *[*]u8,
left_tail: *[*]u8,
left_tail: *const [*]u8,
right_head: *[*]u8,
right_tail: *[*]u8,
right_tail: *const [*]u8,
cmp_data: Opaque,
cmp: CompareFn,
element_width: usize,
@ -216,6 +425,80 @@ fn partial_forward_merge_left_head_2(
return false;
}
test "partial_backwards_merge" {
{
const expected = [10]i64{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
var arr: [10]i64 = undefined;
var arr_ptr = @as([*]u8, @ptrCast(&arr[0]));
var swap: [10]i64 = undefined;
var swap_ptr = @as([*]u8, @ptrCast(&swap[0]));
arr = [10]i64{ 3, 4, 5, 6, 7, 8, 1, 2, 9, 10 };
partial_backwards_merge(arr_ptr, 10, swap_ptr, 10, 6, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy);
try testing.expectEqual(arr, expected);
arr = [10]i64{ 2, 4, 6, 8, 9, 10, 1, 3, 5, 7 };
partial_backwards_merge(arr_ptr, 10, swap_ptr, 10, 6, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy);
try testing.expectEqual(arr, expected);
arr = [10]i64{ 1, 2, 3, 4, 5, 6, 8, 9, 10, 7 };
partial_backwards_merge(arr_ptr, 10, swap_ptr, 10, 9, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy);
try testing.expectEqual(arr, expected);
arr = [10]i64{ 1, 2, 4, 5, 6, 8, 9, 3, 7, 10 };
partial_backwards_merge(arr_ptr, 10, swap_ptr, 9, 7, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy);
try testing.expectEqual(arr, expected);
}
{
var expected: [64]i64 = undefined;
for (0..64) |i| {
expected[i] = @intCast(i + 1);
}
var arr: [64]i64 = undefined;
var arr_ptr = @as([*]u8, @ptrCast(&arr[0]));
var swap: [64]i64 = undefined;
var swap_ptr = @as([*]u8, @ptrCast(&swap[0]));
// chunks
for (0..16) |i| {
arr[i] = @intCast(i + 17);
}
for (0..16) |i| {
arr[i + 16] = @intCast(i + 49);
}
for (0..16) |i| {
arr[i + 32] = @intCast(i + 1);
}
for (0..16) |i| {
arr[i + 48] = @intCast(i + 33);
}
partial_backwards_merge(arr_ptr, 64, swap_ptr, 64, 32, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy);
try testing.expectEqual(arr, expected);
// chunks with break
for (0..16) |i| {
arr[i] = @intCast(i + 17);
}
for (0..16) |i| {
arr[i + 32] = @intCast(i + 1);
}
for (0..16) |i| {
arr[i + 16] = @intCast(i + 49);
}
for (0..16) |i| {
arr[i + 48] = @intCast(i + 34);
}
arr[16] = 33;
arr[63] = 49;
partial_backwards_merge(arr_ptr, 64, swap_ptr, 64, 32, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy);
try testing.expectEqual(arr, expected);
}
}
test "partial_forward_merge" {
const expected = [10]i64{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
@ -329,8 +612,6 @@ fn cross_merge(
if (@intFromPtr(left_tail) - @intFromPtr(left_head) > 8 * element_width) {
// 8 elements all less than or equal to and can be moved together.
while (compare(cmp, cmp_data, left_head + 7 * element_width, right_head) != GT) {
// TODO: Should this actually be a memcpy?
// Memcpy won't know the size until runtime but it is 1 call instead of 8.
inline for (0..8) |_| {
copy(dest_head, left_head);
dest_head += element_width;
@ -343,8 +624,6 @@ fn cross_merge(
// Attempt to do the same from the tail.
// 8 elements all greater than and can be moved together.
while (compare(cmp, cmp_data, left_tail - 7 * element_width, right_tail) == GT) {
// TODO: Should this actually be a memcpy?
// Memcpy won't know the size until runtime but it is 1 call instead of 8.
inline for (0..8) |_| {
copy(dest_tail, left_tail);
dest_tail -= element_width;
@ -359,8 +638,6 @@ fn cross_merge(
if (@intFromPtr(right_tail) - @intFromPtr(right_head) > 8 * element_width) {
// left greater than 8 elements right and can be moved together.
while (compare(cmp, cmp_data, left_head, right_head + 7 * element_width) == GT) {
// TODO: Should this actually be a memcpy?
// Memcpy won't know the size until runtime but it is 1 call instead of 8.
inline for (0..8) |_| {
copy(dest_head, right_head);
dest_head += element_width;
@ -373,8 +650,6 @@ fn cross_merge(
// Attempt to do the same from the tail.
// left less than or equalt to 8 elements right and can be moved together.
while (compare(cmp, cmp_data, left_tail, right_tail - 7 * element_width) != GT) {
// TODO: Should this actually be a memcpy?
// Memcpy won't know the size until runtime but it is 1 call instead of 8.
inline for (0..8) |_| {
copy(dest_tail, right_tail);
dest_tail -= element_width;