From 52af8c588e2ea99b92d9bf579d3b74ce9cd909c5 Mon Sep 17 00:00:00 2001 From: Brendan Hansknecht Date: Mon, 22 Jul 2024 17:47:53 -0700 Subject: [PATCH] start adding the basis for quadsort for blitsort --- crates/compiler/builtins/bitcode/src/list.zig | 66 +--- crates/compiler/builtins/bitcode/src/sort.zig | 293 ++++++++++++++++++ 2 files changed, 300 insertions(+), 59 deletions(-) create mode 100644 crates/compiler/builtins/bitcode/src/sort.zig diff --git a/crates/compiler/builtins/bitcode/src/list.zig b/crates/compiler/builtins/bitcode/src/list.zig index ce82aa4c4e..f61bb3ff19 100644 --- a/crates/compiler/builtins/bitcode/src/list.zig +++ b/crates/compiler/builtins/bitcode/src/list.zig @@ -1,6 +1,7 @@ const std = @import("std"); const utils = @import("utils.zig"); const str = @import("str.zig"); +const sort = @import("sort.zig"); const UpdateMode = utils.UpdateMode; const mem = std.mem; const math = std.math; @@ -690,60 +691,10 @@ pub fn listDropAt( } } -fn partition( - source_ptr: [*]u8, - transform: Opaque, - wrapper: CompareFn, - element_width: usize, - low: isize, - high: isize, - copy: CopyFn, -) isize { - const pivot = source_ptr + (@as(usize, @intCast(high)) * element_width); - var i = (low - 1); // Index of smaller element and indicates the right position of pivot found so far - var j = low; - - while (j <= high - 1) : (j += 1) { - const current_elem = source_ptr + (@as(usize, @intCast(j)) * element_width); - - const ordering = wrapper(transform, current_elem, pivot); - const order = @as(utils.Ordering, @enumFromInt(ordering)); - - switch (order) { - utils.Ordering.LT => { - // the current element is smaller than the pivot; swap it - i += 1; - swapElements(source_ptr, element_width, @as(usize, @intCast(i)), @as(usize, @intCast(j)), copy); - }, - utils.Ordering.EQ, utils.Ordering.GT => {}, - } - } - swapElements(source_ptr, element_width, @as(usize, @intCast(i + 1)), @as(usize, @intCast(high)), copy); - return (i + 1); -} - -fn quicksort( - source_ptr: [*]u8, - transform: Opaque, - wrapper: CompareFn, - element_width: usize, - low: isize, - high: isize, - copy: CopyFn, -) void { - if (low < high) { - // partition index - const pi = partition(source_ptr, transform, wrapper, element_width, low, high, copy); - - _ = quicksort(source_ptr, transform, wrapper, element_width, low, pi - 1, copy); // before pi - _ = quicksort(source_ptr, transform, wrapper, element_width, pi + 1, high, copy); // after pi - } -} - pub fn listSortWith( input: RocList, - caller: CompareFn, - data: Opaque, + cmp: CompareFn, + cmp_data: Opaque, inc_n_data: IncN, data_is_owned: bool, alignment: u32, @@ -753,16 +704,13 @@ pub fn listSortWith( dec: Dec, copy: CopyFn, ) callconv(.C) RocList { + if (input.len() < 2) { + return input; + } var list = input.makeUnique(alignment, element_width, elements_refcounted, inc, dec); - if (data_is_owned) { - inc_n_data(data, list.len()); - } - if (list.bytes) |source_ptr| { - const low = 0; - const high: isize = @as(isize, @intCast(list.len())) - 1; - quicksort(source_ptr, data, caller, element_width, low, high, copy); + sort.quadsort(source_ptr, list.len(), cmp, cmp_data, data_is_owned, inc_n_data, element_width, copy); } return list; diff --git a/crates/compiler/builtins/bitcode/src/sort.zig b/crates/compiler/builtins/bitcode/src/sort.zig new file mode 100644 index 0000000000..10fcc122a3 --- /dev/null +++ b/crates/compiler/builtins/bitcode/src/sort.zig @@ -0,0 +1,293 @@ +const std = @import("std"); +const testing = std.testing; + +const utils = @import("utils.zig"); +const roc_panic = @import("panic.zig").panic_help; + +const Opaque = ?[*]u8; +const CompareFn = *const fn (Opaque, Opaque, Opaque) callconv(.C) u8; +const CopyFn = *const fn (Opaque, Opaque) callconv(.C) void; +const IncN = *const fn (?[*]u8, usize) callconv(.C) void; + +/// Any size larger than the max element buffer will be sorted indirectly via pointers. +const MAX_ELEMENT_BUFFER_SIZE: usize = 64; + +pub fn quadsort( + source_ptr: [*]u8, + len: usize, + cmp: CompareFn, + cmp_data: Opaque, + data_is_owned: bool, + inc_n_data: IncN, + element_width: usize, + copy: CopyFn, +) void { + if (element_width <= MAX_ELEMENT_BUFFER_SIZE) { + quadsort_direct(source_ptr, len, cmp, cmp_data, data_is_owned, inc_n_data, element_width, copy); + } else { + roc_panic("todo: fallback to an indirect pointer sort", 0); + } +} + +fn quadsort_direct( + source_ptr: [*]u8, + len: usize, + cmp: CompareFn, + cmp_data: Opaque, + data_is_owned: bool, + inc_n_data: IncN, + element_width: usize, + copy: CopyFn, +) void { + _ = inc_n_data; + _ = data_is_owned; + _ = cmp_data; + _ = len; + _ = copy; + _ = element_width; + _ = cmp; + _ = source_ptr; + roc_panic("todo: quadsort", 0); +} + +/// Merge two neighboring sorted 4 element arrays into swap. +inline fn parity_merge_four(ptr: [*]u8, swap: [*]u8, cmp_data: Opaque, cmp: CompareFn, element_width: usize, copy: CopyFn) void { + var left = ptr; + var right = ptr + (4 * element_width); + var swap_ptr = swap; + head_branchless_merge(&swap_ptr, &left, &right, cmp_data, cmp, element_width, copy); + head_branchless_merge(&swap_ptr, &left, &right, cmp_data, cmp, element_width, copy); + head_branchless_merge(&swap_ptr, &left, &right, cmp_data, cmp, element_width, copy); + const lte = @as(utils.Ordering, @enumFromInt(cmp(cmp_data, left, right))) != utils.Ordering.GT; + var to_copy = if (lte) left else right; + copy(swap_ptr, to_copy); + + left = ptr + (3 * element_width); + right = ptr + (7 * element_width); + swap_ptr = swap + (7 * element_width); + tail_branchless_merge(&swap_ptr, &left, &right, cmp_data, cmp, element_width, copy); + tail_branchless_merge(&swap_ptr, &left, &right, cmp_data, cmp, element_width, copy); + tail_branchless_merge(&swap_ptr, &left, &right, cmp_data, cmp, element_width, copy); + const gt = @as(utils.Ordering, @enumFromInt(cmp(cmp_data, left, right))) == utils.Ordering.GT; + to_copy = if (gt) left else right; + copy(swap_ptr, to_copy); +} + +/// Merge two neighboring sorted 2 element arrays into swap. +inline fn parity_merge_two(ptr: [*]u8, swap: [*]u8, cmp_data: Opaque, cmp: CompareFn, element_width: usize, copy: CopyFn) void { + var left = ptr; + var right = ptr + (2 * element_width); + var swap_ptr = swap; + head_branchless_merge(&swap_ptr, &left, &right, cmp_data, cmp, element_width, copy); + const lte = @as(utils.Ordering, @enumFromInt(cmp(cmp_data, left, right))) != utils.Ordering.GT; + var to_copy = if (lte) left else right; + copy(swap_ptr, to_copy); + + left = ptr + element_width; + right = ptr + (3 * element_width); + swap_ptr = swap + (3 * element_width); + tail_branchless_merge(&swap_ptr, &left, &right, cmp_data, cmp, element_width, copy); + const gt = @as(utils.Ordering, @enumFromInt(cmp(cmp_data, left, right))) == utils.Ordering.GT; + to_copy = if (gt) left else right; + copy(swap_ptr, to_copy); +} + +/// Moves the smaller element from left and rigth to dest. +/// Will increment both dest and the smaller element ptr to their next index. +/// Inlining will remove the extra level of pointer indirection here. +/// It is just used to allow mutating the input pointers. +inline fn head_branchless_merge(dest: *[*]u8, left: *[*]u8, right: *[*]u8, cmp_data: Opaque, cmp: CompareFn, element_width: usize, copy: CopyFn) void { + // Note there is a much simpler version here: + // *ptd++ = cmp(ptl, ptr) <= 0 ? *ptl++ : *ptr++; + // That said, it is only used with gcc, so I assume it has optimization issues with llvm. + // Thus using the longer form. + const lte = @as(utils.Ordering, @enumFromInt(cmp(cmp_data, left.*, right.*))) != utils.Ordering.GT; + // TODO: double check this is branchless. + const x = if (lte) element_width else 0; + const not_x = if (lte) 0 else element_width; + copy(dest.*, left.*); + left.* += x; + copy((dest.* + x), right.*); + right.* += not_x; + dest.* += element_width; +} + +/// Moves the smaller element from left and rigth to dest. +/// Will decrement both dest and the smaller element ptr to their previous index. +/// Inlining will remove the extra level of pointer indirection here. +/// It is just used to allow mutating the input pointers. +inline fn tail_branchless_merge(dest: *[*]u8, left: *[*]u8, right: *[*]u8, cmp_data: Opaque, cmp: CompareFn, element_width: usize, copy: CopyFn) void { + // Note there is a much simpler version here: + // *tpd-- = cmp(tpl, tpr) > 0 ? *tpl-- : *tpr--; + // That said, it is only used with gcc, so I assume it has optimization issues with llvm. + // Thus using the longer form. + const lte = @as(utils.Ordering, @enumFromInt(cmp(cmp_data, left.*, right.*))) != utils.Ordering.GT; + // TODO: double check this is branchless. + const y = if (lte) element_width else 0; + const not_y = if (lte) 0 else element_width; + copy(dest.*, left.*); + left.* -= not_y; + dest.* -= element_width; + copy((dest.* + y), right.*); + right.* -= y; +} + +/// Swaps the element at ptr with the element after it if the element is greater than the next. +inline fn swap_branchless(ptr: [*]u8, swap: [*]u8, cmp_data: Opaque, cmp: CompareFn, element_width: usize, copy: CopyFn) void { + const gt = @as(utils.Ordering, @enumFromInt(cmp(cmp_data, ptr, ptr + element_width))) == utils.Ordering.GT; + // TODO: double check this is branchless. I would expect llvm to optimize this to be branchless. + // But based on reading some comments in quadsort, llvm seems to prefer branches very often. + const x = if (gt) element_width else 0; + const y = if (gt) 0 else element_width; + + copy(swap, ptr + y); + copy(ptr, ptr + x); + copy(ptr + element_width, swap); +} + +test "parity_merge_four" { + var arr = [8]i64{ 1, 2, 3, 4, 5, 6, 7, 8 }; + var swap = [8]i64{ 0, 0, 0, 0, 0, 0, 0, 0 }; + + var arr_ptr = @as([*]u8, @ptrCast(&arr[0])); + var swap_ptr = @as([*]u8, @ptrCast(&swap[0])); + + parity_merge_four(arr_ptr, swap_ptr, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy); + + try testing.expectEqual(swap, [8]i64{ 1, 2, 3, 4, 5, 6, 7, 8 }); + + arr = [8]i64{ 5, 6, 7, 8, 1, 2, 3, 4 }; + swap = [8]i64{ 0, 0, 0, 0, 0, 0, 0, 0 }; + + parity_merge_four(arr_ptr, swap_ptr, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy); + + try testing.expectEqual(swap, [8]i64{ 1, 2, 3, 4, 5, 6, 7, 8 }); + + arr = [8]i64{ 1, 3, 5, 7, 2, 4, 6, 8 }; + swap = [8]i64{ 0, 0, 0, 0, 0, 0, 0, 0 }; + + parity_merge_four(arr_ptr, swap_ptr, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy); + + try testing.expectEqual(swap, [8]i64{ 1, 2, 3, 4, 5, 6, 7, 8 }); +} + +test "parity_merge_two" { + var arr = [4]i64{ 1, 2, 3, 4 }; + var swap = [4]i64{ 0, 0, 0, 0 }; + + var arr_ptr = @as([*]u8, @ptrCast(&arr[0])); + var swap_ptr = @as([*]u8, @ptrCast(&swap[0])); + + parity_merge_two(arr_ptr, swap_ptr, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy); + + try testing.expectEqual(swap, [4]i64{ 1, 2, 3, 4 }); + + arr = [4]i64{ 1, 3, 2, 4 }; + swap = [4]i64{ 0, 0, 0, 0 }; + + parity_merge_two(arr_ptr, swap_ptr, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy); + + try testing.expectEqual(swap, [4]i64{ 1, 2, 3, 4 }); + + arr = [4]i64{ 3, 4, 1, 2 }; + swap = [4]i64{ 0, 0, 0, 0 }; + + parity_merge_two(arr_ptr, swap_ptr, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy); + + try testing.expectEqual(swap, [4]i64{ 1, 2, 3, 4 }); + + arr = [4]i64{ 2, 4, 1, 3 }; + swap = [4]i64{ 0, 0, 0, 0 }; + + parity_merge_two(arr_ptr, swap_ptr, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy); + + try testing.expectEqual(swap, [4]i64{ 1, 2, 3, 4 }); + + arr = [4]i64{ 1, 4, 2, 3 }; + swap = [4]i64{ 0, 0, 0, 0 }; + + parity_merge_two(arr_ptr, swap_ptr, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy); + + try testing.expectEqual(swap, [4]i64{ 1, 2, 3, 4 }); +} + +test "head_merge" { + var dest = [6]i64{ 0, 0, 0, 0, 0, 0 }; + var left = [4]i64{ 1, 7, 10, 22 }; + var right = [4]i64{ 2, 2, 8, 22 }; + var dest_ptr = @as([*]u8, @ptrCast(&dest[0])); + var left_ptr = @as([*]u8, @ptrCast(&left[0])); + var right_ptr = @as([*]u8, @ptrCast(&right[0])); + + head_branchless_merge(&dest_ptr, &left_ptr, &right_ptr, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy); + head_branchless_merge(&dest_ptr, &left_ptr, &right_ptr, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy); + head_branchless_merge(&dest_ptr, &left_ptr, &right_ptr, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy); + head_branchless_merge(&dest_ptr, &left_ptr, &right_ptr, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy); + head_branchless_merge(&dest_ptr, &left_ptr, &right_ptr, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy); + head_branchless_merge(&dest_ptr, &left_ptr, &right_ptr, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy); + + try testing.expectEqual(dest, [6]i64{ 1, 2, 2, 7, 8, 10 }); +} + +test "tail_merge" { + var dest = [6]i64{ 0, 0, 0, 0, 0, 0 }; + var left = [4]i64{ -22, 1, 7, 10 }; + var right = [4]i64{ -22, 2, 2, 8 }; + var dest_ptr = @as([*]u8, @ptrCast(&dest[dest.len - 1])); + var left_ptr = @as([*]u8, @ptrCast(&left[left.len - 1])); + var right_ptr = @as([*]u8, @ptrCast(&right[right.len - 1])); + + tail_branchless_merge(&dest_ptr, &left_ptr, &right_ptr, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy); + tail_branchless_merge(&dest_ptr, &left_ptr, &right_ptr, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy); + tail_branchless_merge(&dest_ptr, &left_ptr, &right_ptr, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy); + tail_branchless_merge(&dest_ptr, &left_ptr, &right_ptr, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy); + tail_branchless_merge(&dest_ptr, &left_ptr, &right_ptr, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy); + tail_branchless_merge(&dest_ptr, &left_ptr, &right_ptr, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy); + + try testing.expectEqual(dest, [6]i64{ 1, 2, 2, 7, 8, 10 }); +} + +test "swap" { + var arr = [2]i64{ 10, 20 }; + var arr_ptr = @as([*]u8, @ptrCast(&arr[0])); + var swap: i64 = 0; + var swap_ptr = @as([*]u8, @ptrCast(&swap)); + + swap_branchless(arr_ptr, swap_ptr, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy); + + try testing.expectEqual(arr[0], 10); + try testing.expectEqual(arr[1], 20); + + arr[0] = 77; + arr[1] = -12; + + swap_branchless(arr_ptr, swap_ptr, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy); + + try testing.expectEqual(arr[0], -12); + try testing.expectEqual(arr[1], 77); + + arr[0] = -22; + arr[1] = -22; + + swap_branchless(arr_ptr, swap_ptr, null, &test_i64_compare, @sizeOf(i64), &test_i64_copy); + + try testing.expectEqual(arr[0], -22); + try testing.expectEqual(arr[1], -22); +} + +fn test_i64_compare(_: Opaque, a_ptr: Opaque, b_ptr: Opaque) callconv(.C) u8 { + const a = @as(*i64, @alignCast(@ptrCast(a_ptr))).*; + const b = @as(*i64, @alignCast(@ptrCast(b_ptr))).*; + + const gt = @as(u8, @intFromBool(a > b)); + const lt = @as(u8, @intFromBool(a < b)); + + // Eq = 0 + // GT = 1 + // LT = 2 + return lt + lt + gt; +} + +fn test_i64_copy(dst_ptr: Opaque, src_ptr: Opaque) callconv(.C) void { + @as(*i64, @alignCast(@ptrCast(dst_ptr))).* = @as(*i64, @alignCast(@ptrCast(src_ptr))).*; +}