Refactor RocDec.toStr

This commit is contained in:
Jared Ramirez 2021-06-18 12:50:15 -07:00
parent b6f0a3f693
commit 88aa8318f1

View file

@ -7,9 +7,10 @@ const RocStr = str.RocStr;
pub const RocDec = struct {
num: i128,
pub const decimal_places: comptime u32 = 18;
pub const whole_number_places: comptime u32 = 21;
const max_digits: comptime u32 = decimal_places + whole_number_places;
pub const decimal_places: comptime u5 = 18;
pub const whole_number_places: comptime u5 = 21;
const max_digits: comptime u6 = 39;
const leading_zeros: comptime [17]u8 = "00000000000000000".*;
pub const min: comptime RocDec = .{ .num = math.minInt(i128) };
pub const max: comptime RocDec = .{ .num = math.maxInt(i128) };
@ -21,7 +22,6 @@ pub const RocDec = struct {
return .{ .num = num * one_point_zero_i128 };
}
// TODO: Refactor this to use https://ziglang.org/documentation/master/#ctz
pub fn fromStr(roc_str: RocStr) ?RocDec {
if (roc_str.isEmpty()) {
return null;
@ -104,15 +104,16 @@ pub const RocDec = struct {
return dec;
}
// TODO: Replace this with https://github.com/rtfeldman/roc/pull/1365/files#r643580738
fn isDigit(c: u8) bool {
return switch (c) {
'0'...'9' => true,
else => false,
};
return (c -% 48) <= 9;
}
pub fn toStr(self: RocDec) ?RocStr {
// Special case
if (self.num == 0) {
return RocStr.init("0.0", 3);
}
// Check if this Dec is negative, and if so convert to positive
// We will handle adding the '-' later
const is_negative = self.num < 0;
@ -122,80 +123,70 @@ pub const RocDec = struct {
// Format the backing i128 into an array of digits (u8s)
var digit_bytes: [max_digits + 1]u8 = undefined;
var num_digits_formatted = std.fmt.formatIntBuf(digit_bytes[0..], num, 10, false, .{});
var num_digits = std.fmt.formatIntBuf(digit_bytes[0..], num, 10, false, .{});
// If self < 1, then pad digit_bytes with '0' to be at least 18 digits
if (num_digits_formatted < decimal_places) {
var diff = decimal_places - num_digits_formatted;
var padded_digit_bytes: [max_digits + 1]u8 = undefined;
var index: usize = 0;
while (index < decimal_places) {
if (index < diff) {
padded_digit_bytes[index] = '0';
} else {
padded_digit_bytes[index] = digit_bytes[index - diff];
}
index += 1;
}
num_digits_formatted = num_digits_formatted + diff;
digit_bytes = padded_digit_bytes;
}
// Get the slice of the part before the decimal
// If this is empty, then hardcode a '0'
var before_digits_num_raw: usize = undefined;
var before_digits_slice: []const u8 = undefined;
if (num_digits_formatted > decimal_places) {
before_digits_num_raw = num_digits_formatted - decimal_places;
before_digits_slice = digit_bytes[0..before_digits_num_raw];
// Get the slice for before the decimal point
var before_digits_slice_t: []const u8 = undefined;
var before_digits_offset: usize = 0;
var before_digits_adjust: u6 = 0;
if (num_digits > decimal_places) {
before_digits_offset = num_digits - decimal_places;
before_digits_slice_t = digit_bytes[0..before_digits_offset];
} else {
before_digits_num_raw = 0;
before_digits_slice = "0";
before_digits_adjust = @intCast(u6, math.absInt(@intCast(i7, num_digits) - decimal_places) catch {
std.debug.panic("TODO runtime exception for overflow when getting abs", .{});
});
before_digits_slice_t = "0";
}
// Figure out the index where the trailing zeros start
var index = decimal_places - 1;
var trim_index: ?usize = null;
// Figure out how many trailing zeros there are
// I tried to use https://ziglang.org/documentation/0.8.0/#ctz and it mostly worked,
// but was giving seemingly incorrect values for certain numbers. So instead we use
// a while loop and figure it out that way.
//
// const trailing_zeros = @ctz(u6, num);
//
var trailing_zeros: u6 = 0;
var index = decimal_places - 1 - before_digits_adjust;
var is_consecutive_zero = true;
while (index != 0) {
var digit = digit_bytes[before_digits_num_raw + index];
// 48 => '0', 170 => ''
if ((digit == 48 or digit == 170) and is_consecutive_zero) {
trim_index = index;
var digit = digit_bytes[before_digits_offset + index];
if (digit == '0' and is_consecutive_zero) {
trailing_zeros += 1;
} else {
is_consecutive_zero = false;
}
index -= 1;
}
// Get the slice of the part afterthe decimal
var after_digits_slice: []const u8 = undefined;
after_digits_slice = digit_bytes[before_digits_num_raw..(before_digits_num_raw + if (trim_index) |i| i else decimal_places)];
// Figure out if we need to prepend any zeros to the after decimal point
// For example, for the number 1.00023 we need to prepend 3 zeros after the decimal point
const after_zeros_num = if (num_digits < decimal_places) decimal_places - num_digits else 0;
const after_zeros_slice: []const u8 = leading_zeros[0..after_zeros_num];
// Make the RocStr
var sign_len: usize = if (is_negative) 1 else 0;
var dot_len: usize = 1;
var str_len: usize = sign_len + before_digits_slice.len + dot_len + after_digits_slice.len;
// TODO: Ideally we'd use [str_len]u8 here, but Zig gives an error if we do that.
// [max_digits + 2]u8 here to account for '.' and '-', aka the max possible length of the string
var str_bytes: [max_digits + 2]u8 = undefined;
// Join the whole number slice & the decimal slice together
// The format template arg in bufPrint is `comptime`, so we have to repeate the whole statement in each branch
if (is_negative) {
_ = std.fmt.bufPrint(str_bytes[0 .. str_len + 1], "-{s}.{s}", .{ before_digits_slice, after_digits_slice }) catch {
std.debug.panic("TODO runtime exception failing to print slices", .{});
};
// Get the slice for after the decimal point
var after_digits_slice_t: []const u8 = undefined;
if ((num_digits - before_digits_offset) == trailing_zeros) {
after_digits_slice_t = "0";
} else {
_ = std.fmt.bufPrint(str_bytes[0 .. str_len + 1], "{s}.{s}", .{ before_digits_slice, after_digits_slice }) catch {
std.debug.panic("TODO runtime exception failing to print slices", .{});
};
after_digits_slice_t = digit_bytes[before_digits_offset .. num_digits - trailing_zeros];
}
return RocStr.init(&str_bytes, str_len);
// Get the slice for the sign
const sign_slice: []const u8 = if (is_negative) "-" else leading_zeros[0..0];
// Hardcode adding a `1` for the '.' character
const str_len_t: usize = sign_slice.len + before_digits_slice_t.len + 1 + after_zeros_slice.len + after_digits_slice_t.len;
// Join the slices together
// We do `max_digits + 2` here becuase we need to account for a possible sign ('-') and the dot ('.').
// Ideally, we'd use str_len_t here
var str_bytes_t: [max_digits + 2]u8 = undefined;
_ = std.fmt.bufPrint(str_bytes_t[0..str_len_t], "{s}{s}.{s}{s}", .{ sign_slice, before_digits_slice_t, after_zeros_slice, after_digits_slice_t }) catch {
std.debug.panic("TODO runtime exception failing to print slices", .{});
};
return RocStr.init(&str_bytes_t, str_len_t);
}
pub fn negate(self: RocDec) ?RocDec {
@ -648,11 +639,6 @@ fn div_u256_by_u128(numer: U256, denom: u128) U256 {
return .{ .hi = hi, .lo = lo };
}
fn num_of_trailing_zeros(num: u128) u32 {
const trailing: u8 = @ctz(u128, num);
return @intCast(u32, trailing);
}
const testing = std.testing;
const expectEqual = testing.expectEqual;
const expectError = testing.expectError;