have zig generate more efficient copy functions

This commit is contained in:
Brendan Hansknecht 2024-07-21 14:56:21 -07:00
parent db9a5fd261
commit fec875d045
No known key found for this signature in database
GPG key ID: 0EA784685083E75B

View file

@ -6,10 +6,12 @@ const mem = std.mem;
const math = std.math;
const expect = std.testing.expect;
const expectEqual = std.testing.expectEqual;
const EqFn = *const fn (?[*]u8, ?[*]u8) callconv(.C) bool;
const CompareFn = *const fn (?[*]u8, ?[*]u8, ?[*]u8) callconv(.C) u8;
const Opaque = ?[*]u8;
const EqFn = *const fn (Opaque, Opaque) callconv(.C) bool;
const CompareFn = *const fn (Opaque, Opaque, Opaque) callconv(.C) u8;
const CopyFn = *const fn (Opaque, Opaque, usize) void;
const Inc = *const fn (?[*]u8) callconv(.C) void;
const IncN = *const fn (?[*]u8, usize) callconv(.C) void;
@ -433,7 +435,7 @@ pub fn listAppendUnsafe(
if (output.bytes) |bytes| {
if (element) |source| {
const target = bytes + old_length * element_width;
@memcpy(target[0..element_width], source[0..element_width]);
copy_element_fn(element_width)(target, source, element_width);
}
}
@ -468,20 +470,14 @@ pub fn listPrepend(
// can't use one memcpy here because source and target overlap
if (with_capacity.bytes) |target| {
var i: usize = old_length;
while (i > 0) {
i -= 1;
// move the ith element to the (i + 1)th position
const to = target + (i + 1) * element_width;
const from = target + i * element_width;
@memcpy(to[0..element_width], from[0..element_width]);
}
const from = target;
const to = target + element_width;
const size = element_width * old_length;
std.mem.copyBackwards(u8, to[0..size], from[0..size]);
// finally copy in the new first element
if (element) |source| {
@memcpy(target[0..element_width], source[0..element_width]);
copy_element_fn(element_width)(target, source, element_width);
}
}
@ -519,7 +515,8 @@ pub fn listSwap(
const source_ptr = @as([*]u8, @ptrCast(newList.bytes));
swapElements(source_ptr, element_width, @as(usize,
const copy_fn_ptr = copy_element_fn(element_width);
swapElements(copy_fn_ptr, source_ptr, element_width, @as(usize,
// We already verified that both indices are less than the stored list length,
// which is usize, so casting them to usize will definitely be lossless.
@intCast(index_1)), @as(usize, @intCast(index_2)));
@ -653,12 +650,9 @@ pub fn listDropAt(
if (list.isUnique()) {
var i = drop_index;
while (i < size - 1) : (i += 1) {
const copy_target = source_ptr + i * element_width;
const copy_source = copy_target + element_width;
@memcpy(copy_target[0..element_width], copy_source[0..element_width]);
}
const copy_target = source_ptr;
const copy_source = copy_target + element_width;
std.mem.copyForwards(u8, copy_target[i..size], copy_source[i..size]);
var new_list = list;
@ -693,7 +687,7 @@ pub fn listDropAt(
}
}
fn partition(source_ptr: [*]u8, transform: Opaque, wrapper: CompareFn, element_width: usize, low: isize, high: isize) isize {
fn partition(copy_fn_ptr: CopyFn, source_ptr: [*]u8, transform: Opaque, wrapper: CompareFn, element_width: usize, low: isize, high: isize) isize {
const pivot = source_ptr + (@as(usize, @intCast(high)) * element_width);
var i = (low - 1); // Index of smaller element and indicates the right position of pivot found so far
var j = low;
@ -708,22 +702,22 @@ fn partition(source_ptr: [*]u8, transform: Opaque, wrapper: CompareFn, element_w
utils.Ordering.LT => {
// the current element is smaller than the pivot; swap it
i += 1;
swapElements(source_ptr, element_width, @as(usize, @intCast(i)), @as(usize, @intCast(j)));
swapElements(copy_fn_ptr, source_ptr, element_width, @as(usize, @intCast(i)), @as(usize, @intCast(j)));
},
utils.Ordering.EQ, utils.Ordering.GT => {},
}
}
swapElements(source_ptr, element_width, @as(usize, @intCast(i + 1)), @as(usize, @intCast(high)));
swapElements(copy_fn_ptr, source_ptr, element_width, @as(usize, @intCast(i + 1)), @as(usize, @intCast(high)));
return (i + 1);
}
fn quicksort(source_ptr: [*]u8, transform: Opaque, wrapper: CompareFn, element_width: usize, low: isize, high: isize) void {
fn quicksort(copy_fn_ptr: CopyFn, source_ptr: [*]u8, transform: Opaque, wrapper: CompareFn, element_width: usize, low: isize, high: isize) void {
if (low < high) {
// partition index
const pi = partition(source_ptr, transform, wrapper, element_width, low, high);
const pi = partition(copy_fn_ptr, source_ptr, transform, wrapper, element_width, low, high);
_ = quicksort(source_ptr, transform, wrapper, element_width, low, pi - 1); // before pi
_ = quicksort(source_ptr, transform, wrapper, element_width, pi + 1, high); // after pi
_ = quicksort(copy_fn_ptr, source_ptr, transform, wrapper, element_width, low, pi - 1); // before pi
_ = quicksort(copy_fn_ptr, source_ptr, transform, wrapper, element_width, pi + 1, high); // after pi
}
}
@ -748,7 +742,7 @@ pub fn listSortWith(
if (list.bytes) |source_ptr| {
const low = 0;
const high: isize = @as(isize, @intCast(list.len())) - 1;
quicksort(source_ptr, data, caller, element_width, low, high);
quicksort(copy_element_fn(element_width), source_ptr, data, caller, element_width, low, high);
}
return list;
@ -756,29 +750,35 @@ pub fn listSortWith(
// SWAP ELEMENTS
inline fn swapHelp(width: usize, temporary: [*]u8, ptr1: [*]u8, ptr2: [*]u8) void {
@memcpy(temporary[0..width], ptr1[0..width]);
@memcpy(ptr1[0..width], ptr2[0..width]);
@memcpy(ptr2[0..width], temporary[0..width]);
}
fn swap(width_initial: usize, p1: [*]u8, p2: [*]u8) void {
const threshold: usize = 64;
var width = width_initial;
var ptr1 = p1;
var ptr2 = p2;
fn swap(copy_fn_ptr: CopyFn, element_width: usize, p1: [*]u8, p2: [*]u8) void {
const threshold: usize = @sizeOf(u256);
var buffer_actual: [threshold]u8 = undefined;
const buffer: [*]u8 = buffer_actual[0..];
if (element_width <= threshold) {
copy_fn_ptr(buffer, p1, element_width);
copy_fn_ptr(p1, p2, element_width);
copy_fn_ptr(p2, buffer, element_width);
return;
}
var width = element_width;
var ptr1 = p1;
var ptr2 = p2;
const copy_buffer_ptr = comptime copy_element_fn(threshold);
while (true) {
if (width < threshold) {
swapHelp(width, buffer, ptr1, ptr2);
@memcpy(buffer[0..width], ptr1[0..width]);
@memcpy(ptr1[0..width], ptr2[0..width]);
@memcpy(ptr2[0..width], buffer[0..width]);
return;
} else {
swapHelp(threshold, buffer, ptr1, ptr2);
copy_buffer_ptr(buffer, ptr1, threshold);
copy_buffer_ptr(ptr1, ptr2, threshold);
copy_buffer_ptr(ptr2, buffer, threshold);
ptr1 += threshold;
ptr2 += threshold;
@ -788,11 +788,11 @@ fn swap(width_initial: usize, p1: [*]u8, p2: [*]u8) void {
}
}
fn swapElements(source_ptr: [*]u8, element_width: usize, index_1: usize, index_2: usize) void {
fn swapElements(copy_fn_ptr: CopyFn, source_ptr: [*]u8, element_width: usize, index_1: usize, index_2: usize) void {
const element_at_i = source_ptr + (index_1 * element_width);
const element_at_j = source_ptr + (index_2 * element_width);
return swap(element_width, element_at_i, element_at_j);
return swap(copy_fn_ptr, element_width, element_at_i, element_at_j);
}
pub fn listConcat(
@ -952,11 +952,12 @@ inline fn listReplaceInPlaceHelp(
// the element we will replace
var element_at_index = (list.bytes orelse unreachable) + (index * element_width);
const copy_fn_ptr = copy_element_fn(element_width);
// copy out the old element
@memcpy((out_element orelse unreachable)[0..element_width], element_at_index[0..element_width]);
copy_fn_ptr((out_element orelse unreachable), element_at_index, element_width);
// copy in the new element
@memcpy(element_at_index[0..element_width], (element orelse unreachable)[0..element_width]);
copy_fn_ptr(element_at_index, (element orelse unreachable), element_width);
return list;
}
@ -1029,6 +1030,76 @@ pub fn listConcatUtf8(
}
}
/// Selects the copy routine to use for elements that are `element_width` bytes wide.
///
/// Widths equal to the size of u8, u16, u32, u64, u128 or u256 get the matching
/// `memcpy_T` specialization. Every other width gets `memcpy_opaque`, which copies
/// `element_width` bytes with `@memcpy`.
fn copy_element_fn(element_width: usize) CopyFn {
    return switch (element_width) {
        @sizeOf(u8) => memcpy_T(u8),
        @sizeOf(u16) => memcpy_T(u16),
        @sizeOf(u32) => memcpy_T(u32),
        @sizeOf(u64) => memcpy_T(u64),
        @sizeOf(u128) => memcpy_T(u128),
        @sizeOf(u256) => memcpy_T(u256),
        // No fixed-size specialization for this width: use the byte-wise copy.
        else => &memcpy_opaque,
    };
}
/// Copies `element_width` bytes from `src` to `dst` with `@memcpy`.
///
/// This is the `CopyFn` used when the width has no fixed-size specialization.
fn memcpy_opaque(dst: Opaque, src: Opaque, element_width: usize) void {
    const dst_bytes: [*]u8 = @ptrCast(dst);
    const src_bytes: [*]u8 = @ptrCast(src);
    @memcpy(dst_bytes[0..element_width], src_bytes[0..element_width]);
}
/// Builds a `CopyFn` specialized for elements whose size is `@sizeOf(T)`.
///
/// The returned function ignores its third (width) argument and copies exactly one
/// `T` from `src` to `dst` with a single load and store.
///
/// Both pointers are treated as `align(1)`. Callers pass raw byte pointers (for
/// example the `[threshold]u8` scratch buffer in `swap`), which carry no guarantee
/// of `@alignOf(T)` alignment, so an `@alignCast` to `*T` here would be checked
/// illegal behavior whenever the address happens to be misaligned.
fn memcpy_T(comptime T: type) CopyFn {
    return &(struct {
        pub fn memcpy(dst: Opaque, src: Opaque, _: usize) void {
            const dst_ptr: *align(1) T = @ptrCast(dst);
            const src_ptr: *align(1) const T = @ptrCast(src);
            dst_ptr.* = src_ptr.*;
        }
    }.memcpy);
}
test "gen memcpy fn" {
    // One-byte elements: width matches @sizeOf(u8).
    {
        const width = @sizeOf(u8);
        const copy = copy_element_fn(width);

        var dst: u8 = 7;
        var src: u8 = 12;
        copy(@ptrCast(&dst), @ptrCast(&src), width);

        try expectEqual(src, dst);
    }

    // Four-byte elements: width matches @sizeOf(u32).
    {
        const width = @sizeOf(u32);
        const copy = copy_element_fn(width);

        var dst: u32 = 7;
        var src: u32 = 0xDEAD_BEEF;
        copy(@ptrCast(&dst), @ptrCast(&src), width);

        try expectEqual(src, dst);
    }

    // @sizeOf(u512) is not one of the widths listed in copy_element_fn's switch,
    // so this case goes through the `else` branch.
    {
        const width = @sizeOf(u512);
        const copy = copy_element_fn(width);

        var dst: u512 = 7;
        var src: u512 = 1 << 500;
        copy(@ptrCast(&dst), @ptrCast(&src), width);

        try expectEqual(src, dst);
    }
}
test "listConcatUtf8" {
const list = RocList.fromSlice(u8, &[_]u8{ 1, 2, 3, 4 }, false);
defer list.decref(1, 1, false, &rcNone);