mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-26 13:29:12 +00:00
add partial backwards merge
This commit is contained in:
parent
ea0063b992
commit
eb8c91775f
1 changed files with 288 additions and 13 deletions
|
@ -76,6 +76,215 @@ fn quadsort_direct(
|
|||
// ================ Unbalanced Merges =========================================
|
||||
|
||||
/// Merges a full left block with a smaller than block size right chunk.
|
||||
/// The merge goes from tail to head.
|
||||
fn partial_backwards_merge(
|
||||
array: [*]u8,
|
||||
len: usize,
|
||||
swap: [*]u8,
|
||||
swap_len: usize,
|
||||
block_len: usize,
|
||||
cmp_data: Opaque,
|
||||
cmp: CompareFn,
|
||||
element_width: usize,
|
||||
copy: CopyFn,
|
||||
) void {
|
||||
std.debug.assert(swap_len >= block_len);
|
||||
|
||||
if (len == block_len) {
|
||||
// Just a single block, already done.
|
||||
return;
|
||||
}
|
||||
|
||||
var left_tail = array + (block_len - 1) * element_width;
|
||||
var dest_tail = array + (len - 1) * element_width;
|
||||
|
||||
if (compare(cmp, cmp_data, left_tail, left_tail + element_width) != GT) {
|
||||
// Lucky case, blocks happen to be sorted.
|
||||
return;
|
||||
}
|
||||
|
||||
const right_len = len - block_len;
|
||||
if (len <= swap_len and right_len >= 64) {
|
||||
// Large remaining merge and we have enough space to just do it in swap.
|
||||
|
||||
cross_merge(swap, array, block_len, right_len, cmp_data, cmp, element_width, copy);
|
||||
|
||||
@memcpy(array[0..(element_width * len)], swap[0..(element_width * len)]);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
@memcpy(swap[0..(element_width * right_len)], (array + block_len * element_width)[0..(element_width * right_len)]);
|
||||
|
||||
var right_tail = swap + (right_len - 1) * element_width;
|
||||
|
||||
// For backards, we first try to do really large chunks, of 16 elements.
|
||||
outer: while (@intFromPtr(left_tail) > @intFromPtr(array + 16 * element_width) and @intFromPtr(right_tail) > @intFromPtr(swap + 16 * element_width)) {
|
||||
while (compare(cmp, cmp_data, left_tail, right_tail - 15 * element_width) != GT) {
|
||||
inline for (0..16) |_| {
|
||||
copy(dest_tail, right_tail);
|
||||
dest_tail -= element_width;
|
||||
right_tail -= element_width;
|
||||
}
|
||||
if (@intFromPtr(right_tail) <= @intFromPtr(swap + 16 * element_width))
|
||||
break :outer;
|
||||
}
|
||||
while (compare(cmp, cmp_data, left_tail - 15 * element_width, right_tail) == GT) {
|
||||
inline for (0..16) |_| {
|
||||
copy(dest_tail, left_tail);
|
||||
dest_tail -= element_width;
|
||||
left_tail -= element_width;
|
||||
}
|
||||
if (@intFromPtr(left_tail) <= @intFromPtr(array + 16 * element_width))
|
||||
break :outer;
|
||||
}
|
||||
// Attempt to deal with the rest of the chunk in groups of 2.
|
||||
var loops: usize = 8;
|
||||
while (true) {
|
||||
if (compare(cmp, cmp_data, left_tail, right_tail - element_width) != GT) {
|
||||
inline for (0..2) |_| {
|
||||
copy(dest_tail, right_tail);
|
||||
dest_tail -= element_width;
|
||||
right_tail -= element_width;
|
||||
}
|
||||
} else if (compare(cmp, cmp_data, left_tail - element_width, right_tail) == GT) {
|
||||
inline for (0..2) |_| {
|
||||
copy(dest_tail, left_tail);
|
||||
dest_tail -= element_width;
|
||||
left_tail -= element_width;
|
||||
}
|
||||
} else {
|
||||
// Couldn't move two elements, do a cross swap and continue.
|
||||
const lte = compare(cmp, cmp_data, left_tail, right_tail) != GT;
|
||||
var x = if (lte) element_width else 0;
|
||||
var not_x = if (!lte) element_width else 0;
|
||||
dest_tail -= element_width;
|
||||
copy(dest_tail + x, right_tail);
|
||||
right_tail -= element_width;
|
||||
copy(dest_tail + not_x, left_tail);
|
||||
left_tail -= element_width;
|
||||
dest_tail -= element_width;
|
||||
|
||||
tail_branchless_merge(&dest_tail, &left_tail, &right_tail, cmp_data, cmp, element_width, copy);
|
||||
}
|
||||
|
||||
loops -= 1;
|
||||
if (loops == 0)
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// For rest of tail, attempt to merge 2 elements a time from tail to head.
|
||||
while (@intFromPtr(right_tail) > @intFromPtr(swap) + element_width and @intFromPtr(left_tail) > @intFromPtr(array) + element_width) {
|
||||
// Note: I am not sure how to get the same generation as the original C.
|
||||
// This implementation has an extra function call here.
|
||||
// The C use `goto` to implement the two tail recursive functions below inline.
|
||||
const break_loop = partial_forward_merge_right_tail_2(&dest_tail, &array, &left_tail, &swap, &right_tail, cmp_data, cmp, element_width, copy);
|
||||
if (break_loop)
|
||||
break;
|
||||
|
||||
// Couldn't move two elements, do a cross swap and continue.
|
||||
const lte = compare(cmp, cmp_data, left_tail, right_tail) != GT;
|
||||
var x = if (lte) element_width else 0;
|
||||
var not_x = if (!lte) element_width else 0;
|
||||
dest_tail -= element_width;
|
||||
copy(dest_tail + x, right_tail);
|
||||
right_tail -= element_width;
|
||||
copy(dest_tail + not_x, left_tail);
|
||||
left_tail -= element_width;
|
||||
dest_tail -= element_width;
|
||||
|
||||
tail_branchless_merge(&dest_tail, &left_tail, &right_tail, cmp_data, cmp, element_width, copy);
|
||||
}
|
||||
|
||||
// Deal with tail.
|
||||
while (@intFromPtr(right_tail) >= @intFromPtr(swap) and @intFromPtr(left_tail) >= @intFromPtr(array)) {
|
||||
tail_branchless_merge(&dest_tail, &left_tail, &right_tail, cmp_data, cmp, element_width, copy);
|
||||
}
|
||||
while (@intFromPtr(right_tail) >= @intFromPtr(swap)) {
|
||||
copy(dest_tail, right_tail);
|
||||
dest_tail -= element_width;
|
||||
right_tail -= element_width;
|
||||
}
|
||||
}
|
||||
|
||||
// The following two functions are exactly the same but with the if blocks swapped.
|
||||
// They hot loop on one side until it fails, then switch to the other list.
|
||||
|
||||
fn partial_forward_merge_right_tail_2(
|
||||
dest: *[*]u8,
|
||||
left_head: *const [*]u8,
|
||||
left_tail: *[*]u8,
|
||||
right_head: *const [*]u8,
|
||||
right_tail: *[*]u8,
|
||||
cmp_data: Opaque,
|
||||
cmp: CompareFn,
|
||||
element_width: usize,
|
||||
copy: CopyFn,
|
||||
) bool {
|
||||
if (compare(cmp, cmp_data, left_tail.*, right_tail.* - element_width) != GT) {
|
||||
inline for (0..2) |_| {
|
||||
copy(dest.*, right_tail.*);
|
||||
dest.* -= element_width;
|
||||
right_tail.* -= element_width;
|
||||
}
|
||||
if (@intFromPtr(right_tail.*) > @intFromPtr(right_head.*) + element_width) {
|
||||
return @call(.always_tail, partial_forward_merge_right_tail_2, .{ dest, left_head, left_tail, right_head, right_tail, cmp_data, cmp, element_width, copy });
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (compare(cmp, cmp_data, left_tail.* - element_width, right_tail.*) == GT) {
|
||||
inline for (0..2) |_| {
|
||||
copy(dest.*, left_tail.*);
|
||||
dest.* -= element_width;
|
||||
left_tail.* -= element_width;
|
||||
}
|
||||
if (@intFromPtr(left_tail.*) > @intFromPtr(left_head.*) + element_width) {
|
||||
return @call(.always_tail, partial_forward_merge_left_tail_2, .{ dest, left_head, left_tail, right_head, right_tail, cmp_data, cmp, element_width, copy });
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
fn partial_forward_merge_left_tail_2(
|
||||
dest: *[*]u8,
|
||||
left_head: *const [*]u8,
|
||||
left_tail: *[*]u8,
|
||||
right_head: *const [*]u8,
|
||||
right_tail: *[*]u8,
|
||||
cmp_data: Opaque,
|
||||
cmp: CompareFn,
|
||||
element_width: usize,
|
||||
copy: CopyFn,
|
||||
) bool {
|
||||
if (compare(cmp, cmp_data, left_tail.* - element_width, right_tail.*) == GT) {
|
||||
inline for (0..2) |_| {
|
||||
copy(dest.*, left_tail.*);
|
||||
dest.* -= element_width;
|
||||
left_tail.* -= element_width;
|
||||
}
|
||||
if (@intFromPtr(left_tail.*) > @intFromPtr(left_head.*) + element_width) {
|
||||
return @call(.always_tail, partial_forward_merge_left_tail_2, .{ dest, left_head, left_tail, right_head, right_tail, cmp_data, cmp, element_width, copy });
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (compare(cmp, cmp_data, left_tail.*, right_tail.* - element_width) != GT) {
|
||||
inline for (0..2) |_| {
|
||||
copy(dest.*, right_tail.*);
|
||||
dest.* -= element_width;
|
||||
right_tail.* -= element_width;
|
||||
}
|
||||
if (@intFromPtr(right_tail.*) > @intFromPtr(right_head.*) + element_width) {
|
||||
return @call(.always_tail, partial_forward_merge_right_tail_2, .{ dest, left_head, left_tail, right_head, right_tail, cmp_data, cmp, element_width, copy });
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Merges a full left block with a smaller than block size right chunk.
|
||||
/// The merge goes from head to tail.
|
||||
fn partial_forward_merge(
|
||||
array: [*]u8,
|
||||
len: usize,
|
||||
|
@ -98,7 +307,7 @@ fn partial_forward_merge(
|
|||
var right_tail = array + (len - 1) * element_width;
|
||||
|
||||
if (compare(cmp, cmp_data, right_head - element_width, right_head) != GT) {
|
||||
// Luck case, blocks happen to be sorted.
|
||||
// Lucky case, blocks happen to be sorted.
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -147,9 +356,9 @@ fn partial_forward_merge(
|
|||
fn partial_forward_merge_right_head_2(
|
||||
dest: *[*]u8,
|
||||
left_head: *[*]u8,
|
||||
left_tail: *[*]u8,
|
||||
left_tail: *const [*]u8,
|
||||
right_head: *[*]u8,
|
||||
right_tail: *[*]u8,
|
||||
right_tail: *const [*]u8,
|
||||
cmp_data: Opaque,
|
||||
cmp: CompareFn,
|
||||
element_width: usize,
|
||||
|
@ -183,9 +392,9 @@ fn partial_forward_merge_right_head_2(
|
|||
fn partial_forward_merge_left_head_2(
|
||||
dest: *[*]u8,
|
||||
left_head: *[*]u8,
|
||||
left_tail: *[*]u8,
|
||||
left_tail: *const [*]u8,
|
||||
right_head: *[*]u8,
|
||||
right_tail: *[*]u8,
|
||||
right_tail: *const [*]u8,
|
||||
cmp_data: Opaque,
|
||||
cmp: CompareFn,
|
||||
element_width: usize,
|
||||
|
@ -216,6 +425,80 @@ fn partial_forward_merge_left_head_2(
|
|||
return false;
|
||||
}
|
||||
|
||||
test "partial_backwards_merge" {
|
||||
{
|
||||
const expected = [10]i64{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
|
||||
|
||||
var arr: [10]i64 = undefined;
|
||||
var arr_ptr = @as([*]u8, @ptrCast(&arr[0]));
|
||||
var swap: [10]i64 = undefined;
|
||||
var swap_ptr = @as([*]u8, @ptrCast(&swap[0]));
|
||||
|
||||
arr = [10]i64{ 3, 4, 5, 6, 7, 8, 1, 2, 9, 10 };
|
||||
partial_backwards_merge(arr_ptr, 10, swap_ptr, 10, 6, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy);
|
||||
try testing.expectEqual(arr, expected);
|
||||
|
||||
arr = [10]i64{ 2, 4, 6, 8, 9, 10, 1, 3, 5, 7 };
|
||||
partial_backwards_merge(arr_ptr, 10, swap_ptr, 10, 6, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy);
|
||||
try testing.expectEqual(arr, expected);
|
||||
|
||||
arr = [10]i64{ 1, 2, 3, 4, 5, 6, 8, 9, 10, 7 };
|
||||
partial_backwards_merge(arr_ptr, 10, swap_ptr, 10, 9, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy);
|
||||
try testing.expectEqual(arr, expected);
|
||||
|
||||
arr = [10]i64{ 1, 2, 4, 5, 6, 8, 9, 3, 7, 10 };
|
||||
partial_backwards_merge(arr_ptr, 10, swap_ptr, 9, 7, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy);
|
||||
try testing.expectEqual(arr, expected);
|
||||
}
|
||||
|
||||
{
|
||||
var expected: [64]i64 = undefined;
|
||||
for (0..64) |i| {
|
||||
expected[i] = @intCast(i + 1);
|
||||
}
|
||||
|
||||
var arr: [64]i64 = undefined;
|
||||
var arr_ptr = @as([*]u8, @ptrCast(&arr[0]));
|
||||
var swap: [64]i64 = undefined;
|
||||
var swap_ptr = @as([*]u8, @ptrCast(&swap[0]));
|
||||
|
||||
// chunks
|
||||
for (0..16) |i| {
|
||||
arr[i] = @intCast(i + 17);
|
||||
}
|
||||
for (0..16) |i| {
|
||||
arr[i + 16] = @intCast(i + 49);
|
||||
}
|
||||
for (0..16) |i| {
|
||||
arr[i + 32] = @intCast(i + 1);
|
||||
}
|
||||
for (0..16) |i| {
|
||||
arr[i + 48] = @intCast(i + 33);
|
||||
}
|
||||
partial_backwards_merge(arr_ptr, 64, swap_ptr, 64, 32, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy);
|
||||
try testing.expectEqual(arr, expected);
|
||||
|
||||
// chunks with break
|
||||
for (0..16) |i| {
|
||||
arr[i] = @intCast(i + 17);
|
||||
}
|
||||
for (0..16) |i| {
|
||||
arr[i + 32] = @intCast(i + 1);
|
||||
}
|
||||
for (0..16) |i| {
|
||||
arr[i + 16] = @intCast(i + 49);
|
||||
}
|
||||
for (0..16) |i| {
|
||||
arr[i + 48] = @intCast(i + 34);
|
||||
}
|
||||
arr[16] = 33;
|
||||
arr[63] = 49;
|
||||
|
||||
partial_backwards_merge(arr_ptr, 64, swap_ptr, 64, 32, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy);
|
||||
try testing.expectEqual(arr, expected);
|
||||
}
|
||||
}
|
||||
|
||||
test "partial_forward_merge" {
|
||||
const expected = [10]i64{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
|
||||
|
||||
|
@ -329,8 +612,6 @@ fn cross_merge(
|
|||
if (@intFromPtr(left_tail) - @intFromPtr(left_head) > 8 * element_width) {
|
||||
// 8 elements all less than or equal to and can be moved together.
|
||||
while (compare(cmp, cmp_data, left_head + 7 * element_width, right_head) != GT) {
|
||||
// TODO: Should this actually be a memcpy?
|
||||
// Memcpy won't know the size until runtime but it is 1 call instead of 8.
|
||||
inline for (0..8) |_| {
|
||||
copy(dest_head, left_head);
|
||||
dest_head += element_width;
|
||||
|
@ -343,8 +624,6 @@ fn cross_merge(
|
|||
// Attempt to do the same from the tail.
|
||||
// 8 elements all greater than and can be moved together.
|
||||
while (compare(cmp, cmp_data, left_tail - 7 * element_width, right_tail) == GT) {
|
||||
// TODO: Should this actually be a memcpy?
|
||||
// Memcpy won't know the size until runtime but it is 1 call instead of 8.
|
||||
inline for (0..8) |_| {
|
||||
copy(dest_tail, left_tail);
|
||||
dest_tail -= element_width;
|
||||
|
@ -359,8 +638,6 @@ fn cross_merge(
|
|||
if (@intFromPtr(right_tail) - @intFromPtr(right_head) > 8 * element_width) {
|
||||
// left greater than 8 elements right and can be moved together.
|
||||
while (compare(cmp, cmp_data, left_head, right_head + 7 * element_width) == GT) {
|
||||
// TODO: Should this actually be a memcpy?
|
||||
// Memcpy won't know the size until runtime but it is 1 call instead of 8.
|
||||
inline for (0..8) |_| {
|
||||
copy(dest_head, right_head);
|
||||
dest_head += element_width;
|
||||
|
@ -373,8 +650,6 @@ fn cross_merge(
|
|||
// Attempt to do the same from the tail.
|
||||
// left less than or equalt to 8 elements right and can be moved together.
|
||||
while (compare(cmp, cmp_data, left_tail, right_tail - 7 * element_width) != GT) {
|
||||
// TODO: Should this actually be a memcpy?
|
||||
// Memcpy won't know the size until runtime but it is 1 call instead of 8.
|
||||
inline for (0..8) |_| {
|
||||
copy(dest_tail, right_tail);
|
||||
dest_tail -= element_width;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue