Merge pull request #6460 from roc-lang/dec-div-floor

`floor`, `ceiling` and `round` for `Dec`
This commit is contained in:
Brendan Hansknecht 2024-01-30 16:38:12 -08:00 committed by GitHub
commit e7be9d435d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 315 additions and 69 deletions

View file

@ -24,6 +24,9 @@ pub const RocDec = extern struct {
pub const one_point_zero_i128: i128 = math.pow(i128, 10, RocDec.decimal_places);
pub const one_point_zero: RocDec = .{ .num = one_point_zero_i128 };
pub const two_point_zero: RocDec = RocDec.add(RocDec.one_point_zero, RocDec.one_point_zero);
pub const zero_point_five: RocDec = RocDec.div(RocDec.one_point_zero, RocDec.two_point_zero);
pub fn fromU64(num: u64) RocDec {
return .{ .num = num * one_point_zero_i128 };
}
@ -340,6 +343,54 @@ pub const RocDec = extern struct {
}
}
fn trunc(self: RocDec) RocDec {
return RocDec.sub(self, self.fract());
}
fn fract(self: RocDec) RocDec {
const sign = std.math.sign(self.num);
const digits = @mod(sign * self.num, RocDec.one_point_zero.num);
return RocDec{ .num = sign * digits };
}
// Returns the nearest integer to self. If a value is half-way between two integers, round away from 0.0.
fn round(arg1: RocDec) RocDec {
// this rounds towards zero
const tmp = arg1.trunc();
const sign = std.math.sign(arg1.num);
const abs_fract = sign * arg1.fract().num;
if (abs_fract >= RocDec.zero_point_five.num) {
return RocDec.add(tmp, RocDec{ .num = sign * RocDec.one_point_zero.num });
} else {
return tmp;
}
}
// Returns the largest integer less than or equal to itself
fn floor(arg1: RocDec) RocDec {
const tmp = arg1.trunc();
if (arg1.num < 0 and arg1.fract().num != 0) {
return RocDec.sub(tmp, RocDec.one_point_zero);
} else {
return tmp;
}
}
// Returns the smallest integer greater than or equal to itself
fn ceiling(arg1: RocDec) RocDec {
const tmp = arg1.trunc();
if (arg1.num > 0 and arg1.fract().num != 0) {
return RocDec.add(tmp, RocDec.one_point_zero);
} else {
return tmp;
}
}
pub fn mul(self: RocDec, other: RocDec) RocDec {
const answer = RocDec.mulWithOverflow(self, other);
@ -1195,6 +1246,118 @@ test "log: 1" {
try expectEqual(RocDec.fromU64(0), RocDec.log(RocDec.fromU64(1)));
}
test "fract: 0" {
var roc_str = RocStr.init("0", 1);
var dec = RocDec.fromStr(roc_str).?;
try expectEqual(RocDec{ .num = 0 }, dec.fract());
}
test "fract: 1" {
var roc_str = RocStr.init("1", 1);
var dec = RocDec.fromStr(roc_str).?;
try expectEqual(RocDec{ .num = 0 }, dec.fract());
}
test "fract: 123.45" {
var roc_str = RocStr.init("123.45", 6);
var dec = RocDec.fromStr(roc_str).?;
try expectEqual(RocDec{ .num = 450000000000000000 }, dec.fract());
}
test "fract: -123.45" {
var roc_str = RocStr.init("-123.45", 7);
var dec = RocDec.fromStr(roc_str).?;
try expectEqual(RocDec{ .num = -450000000000000000 }, dec.fract());
}
test "fract: .45" {
var roc_str = RocStr.init(".45", 3);
var dec = RocDec.fromStr(roc_str).?;
try expectEqual(RocDec{ .num = 450000000000000000 }, dec.fract());
}
test "fract: -0.00045" {
const dec: RocDec = .{ .num = -450000000000000 };
const res = dec.fract();
try expectEqual(dec.num, res.num);
}
test "trunc: 0" {
var roc_str = RocStr.init("0", 1);
var dec = RocDec.fromStr(roc_str).?;
try expectEqual(RocDec{ .num = 0 }, dec.trunc());
}
test "trunc: 1" {
var roc_str = RocStr.init("1", 1);
var dec = RocDec.fromStr(roc_str).?;
try expectEqual(RocDec.one_point_zero, dec.trunc());
}
test "trunc: 123.45" {
var roc_str = RocStr.init("123.45", 6);
var dec = RocDec.fromStr(roc_str).?;
try expectEqual(RocDec{ .num = 123000000000000000000 }, dec.trunc());
}
test "trunc: -123.45" {
var roc_str = RocStr.init("-123.45", 7);
var dec = RocDec.fromStr(roc_str).?;
try expectEqual(RocDec{ .num = -123000000000000000000 }, dec.trunc());
}
test "trunc: .45" {
var roc_str = RocStr.init(".45", 3);
var dec = RocDec.fromStr(roc_str).?;
try expectEqual(RocDec{ .num = 0 }, dec.trunc());
}
test "trunc: -0.00045" {
const dec: RocDec = .{ .num = -450000000000000 };
const res = dec.trunc();
try expectEqual(RocDec{ .num = 0 }, res);
}
test "round: 123.45" {
var roc_str = RocStr.init("123.45", 6);
var dec = RocDec.fromStr(roc_str).?;
try expectEqual(RocDec{ .num = 123000000000000000000 }, dec.round());
}
test "round: -123.45" {
var roc_str = RocStr.init("-123.45", 7);
var dec = RocDec.fromStr(roc_str).?;
try expectEqual(RocDec{ .num = -123000000000000000000 }, dec.round());
}
test "round: 0.5" {
var roc_str = RocStr.init("0.5", 3);
var dec = RocDec.fromStr(roc_str).?;
try expectEqual(RocDec.one_point_zero, dec.round());
}
test "round: -0.5" {
var roc_str = RocStr.init("-0.5", 4);
var dec = RocDec.fromStr(roc_str).?;
try expectEqual(RocDec{ .num = -1000000000000000000 }, dec.round());
}
// exports
pub fn fromStr(arg: RocStr) callconv(.C) num_.NumParseResult(i128) {
@ -1342,3 +1505,30 @@ pub fn mulOrPanicC(arg1: RocDec, arg2: RocDec) callconv(.C) RocDec {
pub fn mulSaturatedC(arg1: RocDec, arg2: RocDec) callconv(.C) RocDec {
return @call(.always_inline, RocDec.mulSaturated, .{ arg1, arg2 });
}
pub fn exportRound(comptime T: type, comptime name: []const u8) void {
comptime var f = struct {
fn func(input: RocDec) callconv(.C) T {
return @as(T, @intCast(@divFloor(input.round().num, RocDec.one_point_zero_i128)));
}
}.func;
@export(f, .{ .name = name ++ @typeName(T), .linkage = .Strong });
}
pub fn exportFloor(comptime T: type, comptime name: []const u8) void {
comptime var f = struct {
fn func(input: RocDec) callconv(.C) T {
return @as(T, @intCast(@divFloor(input.floor().num, RocDec.one_point_zero_i128)));
}
}.func;
@export(f, .{ .name = name ++ @typeName(T), .linkage = .Strong });
}
pub fn exportCeiling(comptime T: type, comptime name: []const u8) void {
comptime var f = struct {
fn func(input: RocDec) callconv(.C) T {
return @as(T, @intCast(@divFloor(input.ceiling().num, RocDec.one_point_zero_i128)));
}
}.func;
@export(f, .{ .name = name ++ @typeName(T), .linkage = .Strong });
}

View file

@ -52,6 +52,10 @@ comptime {
inline for (INTEGERS) |T| {
dec.exportFromInt(T, ROC_BUILTINS ++ ".dec.from_int.");
dec.exportRound(T, ROC_BUILTINS ++ ".dec.round.");
dec.exportFloor(T, ROC_BUILTINS ++ ".dec.floor.");
dec.exportCeiling(T, ROC_BUILTINS ++ ".dec.ceiling.");
}
}

View file

@ -420,6 +420,9 @@ pub const DEC_SUB_WITH_OVERFLOW: &str = "roc_builtins.dec.sub_with_overflow";
pub const DEC_TAN: &str = "roc_builtins.dec.tan";
pub const DEC_TO_I128: &str = "roc_builtins.dec.to_i128";
pub const DEC_TO_STR: &str = "roc_builtins.dec.to_str";
pub const DEC_ROUND: IntrinsicName = int_intrinsic!("roc_builtins.dec.round");
pub const DEC_FLOOR: IntrinsicName = int_intrinsic!("roc_builtins.dec.floor");
pub const DEC_CEILING: IntrinsicName = int_intrinsic!("roc_builtins.dec.ceiling");
pub const UTILS_DBG_IMPL: &str = "roc_builtins.utils.dbg_impl";
pub const UTILS_TEST_PANIC: &str = "roc_builtins.utils.test_panic";

View file

@ -1106,30 +1106,36 @@ trait Backend<'a> {
self.build_fn_call(sym, intrinsic.to_string(), args, arg_layouts, ret_layout)
}
LowLevel::NumRound => {
let repr = self.interner().get_repr(*ret_layout);
let LayoutRepr::Builtin(Builtin::Int(int_width)) = repr else {
unreachable!("invalid return layout for NumRound")
};
let intrinsic = match arg_layouts[0] {
Layout::F32 => &bitcode::NUM_ROUND_F32[int_width],
Layout::F64 => &bitcode::NUM_ROUND_F64[int_width],
Layout::DEC => &bitcode::DEC_ROUND[int_width],
_ => unreachable!("invalid layout for NumRound"),
};
self.build_fn_call(sym, intrinsic.to_string(), args, arg_layouts, ret_layout)
}
LowLevel::NumFloor => {
let repr = self.interner().get_repr(*ret_layout);
let LayoutRepr::Builtin(Builtin::Int(int_width)) = repr else {
unreachable!("invalid return layout for NumFloor")
};
match arg_layouts[0] {
Layout::F32 => self.build_fn_call(
sym,
bitcode::NUM_FLOOR_F32[int_width].to_string(),
args,
arg_layouts,
ret_layout,
),
Layout::F64 => self.build_fn_call(
sym,
bitcode::NUM_FLOOR_F64[int_width].to_string(),
args,
arg_layouts,
ret_layout,
),
Layout::DEC => todo!("NumFloor for decimals"),
let intrinsic = match arg_layouts[0] {
Layout::F32 => &bitcode::NUM_FLOOR_F32[int_width],
Layout::F64 => &bitcode::NUM_FLOOR_F64[int_width],
Layout::DEC => &bitcode::DEC_FLOOR[int_width],
_ => unreachable!("invalid layout for NumFloor"),
}
};
self.build_fn_call(sym, intrinsic.to_string(), args, arg_layouts, ret_layout)
}
LowLevel::NumCeiling => {
@ -1138,24 +1144,14 @@ trait Backend<'a> {
unreachable!("invalid return layout for NumCeiling")
};
match arg_layouts[0] {
Layout::F32 => self.build_fn_call(
sym,
bitcode::NUM_CEILING_F32[int_width].to_string(),
args,
arg_layouts,
ret_layout,
),
Layout::F64 => self.build_fn_call(
sym,
bitcode::NUM_CEILING_F64[int_width].to_string(),
args,
arg_layouts,
ret_layout,
),
Layout::DEC => todo!("NumCeiling for decimals"),
let intrinsic = match arg_layouts[0] {
Layout::F32 => &bitcode::NUM_CEILING_F32[int_width],
Layout::F64 => &bitcode::NUM_CEILING_F64[int_width],
Layout::DEC => &bitcode::DEC_CEILING[int_width],
_ => unreachable!("invalid layout for NumCeiling"),
}
};
self.build_fn_call(sym, intrinsic.to_string(), args, arg_layouts, ret_layout)
}
LowLevel::NumSub => {
@ -1494,13 +1490,6 @@ trait Backend<'a> {
self.build_fn_call(sym, intrinsic.to_string(), args, arg_layouts, ret_layout)
}
LowLevel::NumRound => self.build_fn_call(
sym,
bitcode::NUM_ROUND_F64[IntWidth::I64].to_string(),
args,
arg_layouts,
ret_layout,
),
LowLevel::ListLen => {
debug_assert_eq!(
1,

View file

@ -2189,11 +2189,13 @@ fn build_dec_unary_op<'a, 'ctx>(
_layout_interner: &STLayoutInterner<'a>,
_parent: FunctionValue<'ctx>,
arg: BasicValueEnum<'ctx>,
_return_layout: InLayout<'a>,
return_layout: InLayout<'a>,
op: LowLevel,
) -> BasicValueEnum<'ctx> {
use roc_module::low_level::LowLevel::*;
let int_width = || return_layout.to_int_width();
match op {
NumAbs => dec_unary_op(env, bitcode::DEC_ABS, arg),
NumAcos => dec_unary_op(env, bitcode::DEC_ACOS, arg),
@ -2203,6 +2205,10 @@ fn build_dec_unary_op<'a, 'ctx>(
NumSin => dec_unary_op(env, bitcode::DEC_SIN, arg),
NumTan => dec_unary_op(env, bitcode::DEC_TAN, arg),
NumRound => dec_unary_op(env, &bitcode::DEC_ROUND[int_width()], arg),
NumFloor => dec_unary_op(env, &bitcode::DEC_FLOOR[int_width()], arg),
NumCeiling => dec_unary_op(env, &bitcode::DEC_CEILING[int_width()], arg),
_ => {
unreachable!("Unrecognized dec unary operation: {:?}", op);
}
@ -2684,42 +2690,39 @@ fn build_float_unary_op<'a, 'ctx>(
LayoutRepr::Builtin(Builtin::Int(int_width)) => int_width,
_ => internal_error!("Ceiling return layout is not int: {:?}", layout),
};
match float_width {
FloatWidth::F32 => {
call_bitcode_fn(env, &[arg.into()], &bitcode::NUM_CEILING_F32[int_width])
}
FloatWidth::F64 => {
call_bitcode_fn(env, &[arg.into()], &bitcode::NUM_CEILING_F64[int_width])
}
}
let intrinsic = match float_width {
FloatWidth::F32 => &bitcode::NUM_CEILING_F32[int_width],
FloatWidth::F64 => &bitcode::NUM_CEILING_F64[int_width],
};
call_bitcode_fn(env, &[arg.into()], intrinsic)
}
NumFloor => {
let int_width = match layout_interner.get_repr(layout) {
LayoutRepr::Builtin(Builtin::Int(int_width)) => int_width,
_ => internal_error!("Floor return layout is not int: {:?}", layout),
};
match float_width {
FloatWidth::F32 => {
call_bitcode_fn(env, &[arg.into()], &bitcode::NUM_FLOOR_F32[int_width])
}
FloatWidth::F64 => {
call_bitcode_fn(env, &[arg.into()], &bitcode::NUM_FLOOR_F64[int_width])
}
}
let intrinsic = match float_width {
FloatWidth::F32 => &bitcode::NUM_FLOOR_F32[int_width],
FloatWidth::F64 => &bitcode::NUM_FLOOR_F64[int_width],
};
call_bitcode_fn(env, &[arg.into()], intrinsic)
}
NumRound => {
let int_width = match layout_interner.get_repr(layout) {
LayoutRepr::Builtin(Builtin::Int(int_width)) => int_width,
_ => internal_error!("Round return layout is not int: {:?}", layout),
};
match float_width {
FloatWidth::F32 => {
call_bitcode_fn(env, &[arg.into()], &bitcode::NUM_ROUND_F32[int_width])
}
FloatWidth::F64 => {
call_bitcode_fn(env, &[arg.into()], &bitcode::NUM_ROUND_F64[int_width])
}
}
let intrinsic = match float_width {
FloatWidth::F32 => &bitcode::NUM_ROUND_F32[int_width],
FloatWidth::F64 => &bitcode::NUM_ROUND_F64[int_width],
};
call_bitcode_fn(env, &[arg.into()], intrinsic)
}
NumIsNan => call_bitcode_fn(env, &[arg.into()], &bitcode::NUM_IS_NAN[float_width]),
NumIsInfinite => {

View file

@ -1636,6 +1636,7 @@ impl<'a> LowLevelCall<'a> {
match arg_type {
F32 => self.load_args_and_call_zig(backend, &bitcode::NUM_ROUND_F32[width]),
F64 => self.load_args_and_call_zig(backend, &bitcode::NUM_ROUND_F64[width]),
Decimal => self.load_args_and_call_zig(backend, &bitcode::DEC_ROUND[width]),
_ => internal_error!("Invalid argument type for round: {:?}", arg_type),
}
}
@ -1643,6 +1644,14 @@ impl<'a> LowLevelCall<'a> {
self.load_args(backend);
let arg_type = CodeGenNumType::for_symbol(backend, self.arguments[0]);
let ret_type = CodeGenNumType::from(self.ret_layout);
let width = match ret_type {
CodeGenNumType::I32 => IntWidth::I32,
CodeGenNumType::I64 => IntWidth::I64,
CodeGenNumType::I128 => todo!("{:?} for I128", self.lowlevel),
_ => internal_error!("Invalid return type for round: {:?}", ret_type),
};
match (arg_type, self.lowlevel) {
(F32, NumCeiling) => {
backend.code_builder.f32_ceil();
@ -1650,14 +1659,21 @@ impl<'a> LowLevelCall<'a> {
(F64, NumCeiling) => {
backend.code_builder.f64_ceil();
}
(Decimal, NumCeiling) => {
return self.load_args_and_call_zig(backend, &bitcode::DEC_CEILING[width]);
}
(F32, NumFloor) => {
backend.code_builder.f32_floor();
}
(F64, NumFloor) => {
backend.code_builder.f64_floor();
}
(Decimal, NumFloor) => {
return self.load_args_and_call_zig(backend, &bitcode::DEC_FLOOR[width]);
}
_ => internal_error!("Invalid argument type for ceiling: {:?}", arg_type),
}
match (ret_type, arg_type) {
// TODO: unsigned truncation
(I32, F32) => backend.code_builder.i32_trunc_s_f32(),

View file

@ -1831,15 +1831,56 @@ fn pow() {
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))]
fn ceiling() {
assert_evals_to!("Num.ceiling 1.1f64", 2, i64);
fn round_f64() {
assert_evals_to!("Num.round 1.9f64", 2, i64);
assert_evals_to!("Num.round -1.9f64", -2, i64);
assert_evals_to!("Num.round 0.5f64", 1, i64);
assert_evals_to!("Num.round -0.5f64", -1, i64);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))]
fn floor() {
fn round_dec() {
assert_evals_to!("Num.round 1.9dec", 2, i64);
assert_evals_to!("Num.round -1.9dec", -2, i64);
assert_evals_to!("Num.round 0.5dec", 1, i64);
assert_evals_to!("Num.round -0.5dec", -1, i64);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))]
fn ceiling_f64() {
assert_evals_to!("Num.ceiling 1.9f64", 2, i64);
assert_evals_to!("Num.ceiling -1.9f64", -1, i64);
assert_evals_to!("Num.ceiling 0.5f64", 1, i64);
assert_evals_to!("Num.ceiling -0.5f64", 0, i64);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))]
fn ceiling_dec() {
assert_evals_to!("Num.ceiling 1.9dec", 2, i64);
assert_evals_to!("Num.ceiling -1.9dec", -1, i64);
assert_evals_to!("Num.ceiling 0.5dec", 1, i64);
assert_evals_to!("Num.ceiling -0.5dec", 0, i64);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))]
fn floor_f64() {
assert_evals_to!("Num.floor 1.9f64", 1, i64);
assert_evals_to!("Num.floor -1.9f64", -2, i64);
assert_evals_to!("Num.floor 0.5f64", 0, i64);
assert_evals_to!("Num.floor -0.5f64", -1, i64);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))]
fn floor_dec() {
assert_evals_to!("Num.floor 1.9dec", 1, i64);
assert_evals_to!("Num.floor -1.9dec", -2, i64);
assert_evals_to!("Num.floor 0.5dec", 0, i64);
assert_evals_to!("Num.floor -0.5dec", -1, i64);
}
#[test]