Num.pow for Dec

This commit is contained in:
Folkert 2024-01-29 20:41:14 +01:00
parent 2e648cfdd5
commit e16b25c93e
No known key found for this signature in database
GPG key ID: 1F17F6FFD112B97C
7 changed files with 125 additions and 2 deletions

View file

@ -391,6 +391,29 @@ pub const RocDec = extern struct {
} }
} }
fn powInt(base: RocDec, exponent: i128) RocDec {
if (exponent == 0) {
return RocDec.one_point_zero;
} else if (exponent > 0) {
if (@mod(exponent, 2) == 0) {
const half_power = RocDec.powInt(base, exponent >> 1); // `>> 1` == `/ 2`
return RocDec.mul(half_power, half_power);
} else {
return RocDec.mul(base, RocDec.powInt(base, exponent - 1));
}
} else {
return RocDec.div(RocDec.one_point_zero, RocDec.powInt(base, -exponent));
}
}
fn pow(base: RocDec, exponent: RocDec) RocDec {
if (exponent.trunc().num == exponent.num) {
return base.powInt(@divTrunc(exponent.num, RocDec.one_point_zero_i128));
} else {
return fromF64(std.math.pow(f64, base.toF64(), exponent.toF64())).?;
}
}
pub fn mul(self: RocDec, other: RocDec) RocDec { pub fn mul(self: RocDec, other: RocDec) RocDec {
const answer = RocDec.mulWithOverflow(self, other); const answer = RocDec.mulWithOverflow(self, other);
@ -1358,6 +1381,41 @@ test "round: -0.5" {
try expectEqual(RocDec{ .num = -1000000000000000000 }, dec.round()); try expectEqual(RocDec{ .num = -1000000000000000000 }, dec.round());
} }
test "powInt: 3.1 ^ 0" {
var roc_str = RocStr.init("3.1", 3);
var dec = RocDec.fromStr(roc_str).?;
try expectEqual(RocDec.one_point_zero, dec.powInt(0));
}
test "powInt: 3.1 ^ 1" {
var roc_str = RocStr.init("3.1", 3);
var dec = RocDec.fromStr(roc_str).?;
try expectEqual(dec, dec.powInt(1));
}
test "powInt: 2 ^ 2" {
var roc_str = RocStr.init("4", 1);
var dec = RocDec.fromStr(roc_str).?;
try expectEqual(dec, RocDec.two_point_zero.powInt(2));
}
test "powInt: 0.5 ^ 2" {
var roc_str = RocStr.init("0.25", 4);
var dec = RocDec.fromStr(roc_str).?;
try expectEqual(dec, RocDec.zero_point_five.powInt(2));
}
test "pow: 0.5 ^ 2.0" {
var roc_str = RocStr.init("0.25", 4);
var dec = RocDec.fromStr(roc_str).?;
try expectEqual(dec, RocDec.zero_point_five.pow(RocDec.two_point_zero));
}
// exports // exports
pub fn fromStr(arg: RocStr) callconv(.C) num_.NumParseResult(i128) { pub fn fromStr(arg: RocStr) callconv(.C) num_.NumParseResult(i128) {
@ -1458,6 +1516,10 @@ pub fn logC(arg: RocDec) callconv(.C) i128 {
return @call(.always_inline, RocDec.log, .{arg}).num; return @call(.always_inline, RocDec.log, .{arg}).num;
} }
pub fn powC(arg1: RocDec, arg2: RocDec) callconv(.C) i128 {
return @call(.always_inline, RocDec.pow, .{ arg1, arg2 }).num;
}
pub fn sinC(arg: RocDec) callconv(.C) i128 { pub fn sinC(arg: RocDec) callconv(.C) i128 {
return @call(.always_inline, RocDec.sin, .{arg}).num; return @call(.always_inline, RocDec.sin, .{arg}).num;
} }

View file

@ -36,6 +36,7 @@ comptime {
exportDecFn(dec.fromStr, "from_str"); exportDecFn(dec.fromStr, "from_str");
exportDecFn(dec.fromU64C, "from_u64"); exportDecFn(dec.fromU64C, "from_u64");
exportDecFn(dec.logC, "log"); exportDecFn(dec.logC, "log");
exportDecFn(dec.powC, "pow");
exportDecFn(dec.mulC, "mul_with_overflow"); exportDecFn(dec.mulC, "mul_with_overflow");
exportDecFn(dec.mulOrPanicC, "mul_or_panic"); exportDecFn(dec.mulOrPanicC, "mul_or_panic");
exportDecFn(dec.mulSaturatedC, "mul_saturated"); exportDecFn(dec.mulSaturatedC, "mul_saturated");

View file

@ -408,6 +408,7 @@ pub const DEC_FROM_INT: IntrinsicName = int_intrinsic!("roc_builtins.dec.from_in
pub const DEC_FROM_STR: &str = "roc_builtins.dec.from_str"; pub const DEC_FROM_STR: &str = "roc_builtins.dec.from_str";
pub const DEC_FROM_U64: &str = "roc_builtins.dec.from_u64"; pub const DEC_FROM_U64: &str = "roc_builtins.dec.from_u64";
pub const DEC_LOG: &str = "roc_builtins.dec.log"; pub const DEC_LOG: &str = "roc_builtins.dec.log";
pub const DEC_POW: &str = "roc_builtins.dec.pow";
pub const DEC_MUL_OR_PANIC: &str = "roc_builtins.dec.mul_or_panic"; pub const DEC_MUL_OR_PANIC: &str = "roc_builtins.dec.mul_or_panic";
pub const DEC_MUL_SATURATED: &str = "roc_builtins.dec.mul_saturated"; pub const DEC_MUL_SATURATED: &str = "roc_builtins.dec.mul_saturated";
pub const DEC_MUL_WITH_OVERFLOW: &str = "roc_builtins.dec.mul_with_overflow"; pub const DEC_MUL_WITH_OVERFLOW: &str = "roc_builtins.dec.mul_with_overflow";

View file

@ -1099,7 +1099,7 @@ trait Backend<'a> {
LayoutRepr::Builtin(Builtin::Float(float_width)) => { LayoutRepr::Builtin(Builtin::Float(float_width)) => {
&bitcode::NUM_POW[float_width] &bitcode::NUM_POW[float_width]
} }
LayoutRepr::DEC => todo!("exponentiation for decimals"), LayoutRepr::DEC => bitcode::DEC_POW,
_ => unreachable!("invalid layout for NumPow"), _ => unreachable!("invalid layout for NumPow"),
}; };

View file

@ -2059,6 +2059,54 @@ fn dec_unary_op<'ctx>(
} }
} }
fn dec_binary_op<'ctx>(
env: &Env<'_, 'ctx, '_>,
fn_name: &str,
dec1: BasicValueEnum<'ctx>,
dec2: BasicValueEnum<'ctx>,
) -> BasicValueEnum<'ctx> {
use roc_target::Architecture::*;
use roc_target::OperatingSystem::*;
let dec1 = dec1.into_int_value();
let dec2 = dec2.into_int_value();
match env.target_info {
TargetInfo {
architecture: X86_64 | X86_32,
operating_system: Unix,
} => {
let (low1, high1) = dec_split_into_words(env, dec1);
let (low2, high2) = dec_split_into_words(env, dec2);
let lowr_highr = call_bitcode_fn(
env,
&[low1.into(), high1.into(), low2.into(), high2.into()],
fn_name,
);
let block = env.builder.get_insert_block().expect("to be in a function");
let parent = block.get_parent().expect("to be in a function");
let ptr =
create_entry_block_alloca(env, parent, env.context.i128_type().into(), "to_i128");
env.builder.build_store(ptr, lowr_highr).unwrap();
env.builder
.build_load(env.context.i128_type(), ptr, "to_i128")
.unwrap()
}
TargetInfo {
architecture: Wasm32,
operating_system: Unix,
} => call_bitcode_fn(env, &[dec1.into(), dec2.into()], fn_name),
_ => call_bitcode_fn(
env,
&[dec_alloca(env, dec1), dec_alloca(env, dec2)],
fn_name,
),
}
}
fn dec_binop_with_overflow<'ctx>( fn dec_binop_with_overflow<'ctx>(
env: &Env<'_, 'ctx, '_>, env: &Env<'_, 'ctx, '_>,
fn_name: &str, fn_name: &str,
@ -2277,6 +2325,7 @@ fn build_dec_binop<'a, 'ctx>(
&[lhs, rhs], &[lhs, rhs],
&bitcode::NUM_GREATER_THAN_OR_EQUAL[IntWidth::I128], &bitcode::NUM_GREATER_THAN_OR_EQUAL[IntWidth::I128],
), ),
NumPow => dec_binary_op(env, bitcode::DEC_POW, lhs, rhs),
_ => { _ => {
unreachable!("Unrecognized dec binary operation: {:?}", op); unreachable!("Unrecognized dec binary operation: {:?}", op);
} }

View file

@ -1585,6 +1585,10 @@ impl<'a> LowLevelCall<'a> {
LayoutRepr::Builtin(Builtin::Float(width)) => { LayoutRepr::Builtin(Builtin::Float(width)) => {
self.load_args_and_call_zig(backend, &bitcode::NUM_POW[width]); self.load_args_and_call_zig(backend, &bitcode::NUM_POW[width]);
} }
LayoutRepr::Builtin(Builtin::Decimal) => {
self.load_args_and_call_zig(backend, bitcode::DEC_POW);
}
_ => panic_ret_type(), _ => panic_ret_type(),
}, },

View file

@ -1819,10 +1819,16 @@ fn float_compare() {
#[test] #[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))] #[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))]
fn pow() { fn pow_f64() {
assert_evals_to!("Num.pow 2.0f64 2.0f64", 4.0, f64); assert_evals_to!("Num.pow 2.0f64 2.0f64", 4.0, f64);
} }
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))]
fn pow_dec() {
assert_evals_to!("Num.pow 2.0dec 2.0dec", RocDec::from(4), RocDec);
}
#[test] #[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))] #[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))]
fn round_f64() { fn round_f64() {