mirror of
https://github.com/roc-lang/roc.git
synced 2025-07-24 15:03:46 +00:00
implement sqrt and log in the dev backend
This commit is contained in:
parent
6a40d75353
commit
44f08f9e47
7 changed files with 168 additions and 88 deletions
|
@ -748,6 +748,7 @@ sqrtChecked = \x ->
|
|||
else
|
||||
Ok (Num.sqrt x)
|
||||
|
||||
## Natural logarithm
|
||||
log : Frac a -> Frac a
|
||||
|
||||
logChecked : Frac a -> Result (Frac a) [LogNeedsPositive]
|
||||
|
|
|
@ -1053,6 +1053,14 @@ impl Assembler<AArch64GeneralReg, AArch64FloatReg> for AArch64Assembler {
|
|||
{
|
||||
todo!("sar for AArch64")
|
||||
}
|
||||
|
||||
fn sqrt_freg64_freg64(_buf: &mut Vec<'_, u8>, _dst: AArch64FloatReg, _src: AArch64FloatReg) {
|
||||
todo!("sqrt")
|
||||
}
|
||||
|
||||
fn sqrt_freg32_freg32(_buf: &mut Vec<'_, u8>, _dst: AArch64FloatReg, _src: AArch64FloatReg) {
|
||||
todo!("sqrt")
|
||||
}
|
||||
}
|
||||
|
||||
impl AArch64Assembler {}
|
||||
|
|
|
@ -317,6 +317,9 @@ pub trait Assembler<GeneralReg: RegTrait, FloatReg: RegTrait>: Sized + Copy {
|
|||
fn mov_stack32_freg64(buf: &mut Vec<'_, u8>, offset: i32, src: FloatReg);
|
||||
fn mov_stack32_reg64(buf: &mut Vec<'_, u8>, offset: i32, src: GeneralReg);
|
||||
|
||||
fn sqrt_freg64_freg64(buf: &mut Vec<'_, u8>, dst: FloatReg, src: FloatReg);
|
||||
fn sqrt_freg32_freg32(buf: &mut Vec<'_, u8>, dst: FloatReg, src: FloatReg);
|
||||
|
||||
fn neg_reg64_reg64(buf: &mut Vec<'_, u8>, dst: GeneralReg, src: GeneralReg);
|
||||
fn mul_freg32_freg32_freg32(
|
||||
buf: &mut Vec<'_, u8>,
|
||||
|
@ -2571,6 +2574,18 @@ impl<
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn build_num_sqrt(&mut self, dst: Symbol, src: Symbol, float_width: FloatWidth) {
|
||||
let buf = &mut self.buf;
|
||||
|
||||
let dst_reg = self.storage_manager.claim_float_reg(buf, &dst);
|
||||
let src_reg = self.storage_manager.load_to_float_reg(buf, &src);
|
||||
|
||||
match float_width {
|
||||
FloatWidth::F32 => ASM::sqrt_freg32_freg32(buf, dst_reg, src_reg),
|
||||
FloatWidth::F64 => ASM::sqrt_freg64_freg64(buf, dst_reg, src_reg),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// This impl block is for ir related instructions that need backend specific information.
|
||||
|
|
|
@ -1745,6 +1745,14 @@ impl Assembler<X86_64GeneralReg, X86_64FloatReg> for X86_64Assembler {
|
|||
{
|
||||
shift_reg64_reg64_reg64(buf, storage_manager, sar_reg64_reg64, dst, src1, src2)
|
||||
}
|
||||
|
||||
fn sqrt_freg64_freg64(buf: &mut Vec<'_, u8>, dst: X86_64FloatReg, src: X86_64FloatReg) {
|
||||
sqrtsd_freg64_freg64(buf, dst, src)
|
||||
}
|
||||
|
||||
fn sqrt_freg32_freg32(buf: &mut Vec<'_, u8>, dst: X86_64FloatReg, src: X86_64FloatReg) {
|
||||
sqrtss_freg32_freg32(buf, dst, src)
|
||||
}
|
||||
}
|
||||
|
||||
fn shift_reg64_reg64_reg64<'a, 'r, ASM, CC>(
|
||||
|
@ -2157,6 +2165,48 @@ fn cmp_freg32_freg32(buf: &mut Vec<'_, u8>, src1: X86_64FloatReg, src2: X86_64Fl
|
|||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn sqrtsd_freg64_freg64(buf: &mut Vec<'_, u8>, dst: X86_64FloatReg, src: X86_64FloatReg) {
|
||||
let dst_high = dst as u8 > 7;
|
||||
let dst_mod = dst as u8 % 8;
|
||||
|
||||
let src_high = src as u8 > 7;
|
||||
let src_mod = src as u8 % 8;
|
||||
|
||||
if dst_high || src_high {
|
||||
buf.extend([
|
||||
0xF2,
|
||||
0x40 | ((dst_high as u8) << 2) | (src_high as u8),
|
||||
0x0F,
|
||||
0x51,
|
||||
0xC0 | (dst_mod << 3) | (src_mod),
|
||||
])
|
||||
} else {
|
||||
buf.extend([0xF2, 0x0F, 0x51, 0xC0 | (dst_mod << 3) | (src_mod)])
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn sqrtss_freg32_freg32(buf: &mut Vec<'_, u8>, dst: X86_64FloatReg, src: X86_64FloatReg) {
|
||||
let dst_high = dst as u8 > 7;
|
||||
let dst_mod = dst as u8 % 8;
|
||||
|
||||
let src_high = src as u8 > 7;
|
||||
let src_mod = src as u8 % 8;
|
||||
|
||||
if dst_high || src_high {
|
||||
buf.extend([
|
||||
0xF3,
|
||||
0x40 | ((dst_high as u8) << 2) | (src_high as u8),
|
||||
0x0F,
|
||||
0x51,
|
||||
0xC0 | (dst_mod << 3) | (src_mod),
|
||||
])
|
||||
} else {
|
||||
buf.extend([0xF3, 0x0F, 0x51, 0xC0 | (dst_mod << 3) | (src_mod)])
|
||||
}
|
||||
}
|
||||
|
||||
/// `TEST r/m64,r64` -> AND r64 with r/m64; set SF, ZF, PF according to result.
|
||||
#[allow(dead_code)]
|
||||
#[inline(always)]
|
||||
|
@ -3601,4 +3651,24 @@ mod tests {
|
|||
fn test_push_reg64() {
|
||||
disassembler_test!(push_reg64, |reg| format!("push {}", reg), ALL_GENERAL_REGS);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sqrt_freg64_freg64() {
|
||||
disassembler_test!(
|
||||
sqrtsd_freg64_freg64,
|
||||
|dst, src| format!("sqrtsd {dst}, {src}"),
|
||||
ALL_FLOAT_REGS,
|
||||
ALL_FLOAT_REGS
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sqrt_freg32_freg32() {
|
||||
disassembler_test!(
|
||||
sqrtss_freg32_freg32,
|
||||
|dst, src| format!("sqrtss {dst}, {src}"),
|
||||
ALL_FLOAT_REGS,
|
||||
ALL_FLOAT_REGS
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -751,6 +751,30 @@ trait Backend<'a> {
|
|||
);
|
||||
self.build_num_gte(sym, &args[0], &args[1], &arg_layouts[0])
|
||||
}
|
||||
LowLevel::NumLogUnchecked => {
|
||||
let float_width = match arg_layouts[0] {
|
||||
Layout::F64 => FloatWidth::F64,
|
||||
Layout::F32 => FloatWidth::F32,
|
||||
_ => unreachable!("invalid layout for sqrt"),
|
||||
};
|
||||
|
||||
self.build_fn_call(
|
||||
sym,
|
||||
bitcode::NUM_LOG[float_width].to_string(),
|
||||
args,
|
||||
arg_layouts,
|
||||
ret_layout,
|
||||
)
|
||||
}
|
||||
LowLevel::NumSqrtUnchecked => {
|
||||
let float_width = match arg_layouts[0] {
|
||||
Layout::F64 => FloatWidth::F64,
|
||||
Layout::F32 => FloatWidth::F32,
|
||||
_ => unreachable!("invalid layout for sqrt"),
|
||||
};
|
||||
|
||||
self.build_num_sqrt(*sym, args[0], float_width);
|
||||
}
|
||||
LowLevel::NumRound => self.build_fn_call(
|
||||
sym,
|
||||
bitcode::NUM_ROUND_F64[IntWidth::I64].to_string(),
|
||||
|
@ -1261,6 +1285,9 @@ trait Backend<'a> {
|
|||
arg_layout: &InLayout<'a>,
|
||||
);
|
||||
|
||||
/// build_sqrt stores the result of `sqrt(src)` into dst.
|
||||
fn build_num_sqrt(&mut self, dst: Symbol, src: Symbol, float_width: FloatWidth);
|
||||
|
||||
/// build_list_len returns the length of a list.
|
||||
fn build_list_len(&mut self, dst: &Symbol, list: &Symbol);
|
||||
|
||||
|
|
|
@ -468,112 +468,51 @@ fn f32_float_alias() {
|
|||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
|
||||
fn f64_sqrt() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
when Num.sqrtChecked 100 is
|
||||
Ok val -> val
|
||||
Err _ -> -1
|
||||
"#
|
||||
),
|
||||
10.0,
|
||||
f64
|
||||
);
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-dev", feature = "gen-wasm"))]
|
||||
fn f64_sqrt_100() {
|
||||
assert_evals_to!("Num.sqrt 100", 10.0, f64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-dev", feature = "gen-wasm"))]
|
||||
fn f64_sqrt_checked_0() {
|
||||
assert_evals_to!("Num.sqrt 0", 0.0, f64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
|
||||
fn f64_log() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
Num.log 7.38905609893
|
||||
"#
|
||||
),
|
||||
1.999999999999912,
|
||||
f64
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
|
||||
fn f64_log_checked_one() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
when Num.logChecked 1 is
|
||||
Ok val -> val
|
||||
Err _ -> -1
|
||||
"#
|
||||
),
|
||||
0.0,
|
||||
f64
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
|
||||
fn f64_sqrt_zero() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
when Num.sqrtChecked 0 is
|
||||
Ok val -> val
|
||||
Err _ -> -1
|
||||
"#
|
||||
),
|
||||
0.0,
|
||||
f64
|
||||
);
|
||||
fn f64_sqrt_checked_positive() {
|
||||
assert_evals_to!("Num.sqrtChecked 100", RocResult::ok(10.0), RocResult<f64, ()>);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
|
||||
fn f64_sqrt_checked_negative() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
when Num.sqrtChecked -1 is
|
||||
Err _ -> 42
|
||||
Ok val -> val
|
||||
"#
|
||||
),
|
||||
42.0,
|
||||
f64
|
||||
);
|
||||
assert_evals_to!("Num.sqrtChecked -1f64", RocResult::err(()), RocResult<f64, ()>);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-dev", feature = "gen-wasm"))]
|
||||
fn f64_log() {
|
||||
assert_evals_to!("Num.log 7.38905609893", 1.999999999999912, f64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
|
||||
fn f64_log_checked_one() {
|
||||
assert_evals_to!("Num.logChecked 1", RocResult::ok(1.0), RocResult<f64, ()>);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
|
||||
fn f64_log_checked_zero() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
when Num.logChecked 0 is
|
||||
Err _ -> 42
|
||||
Ok val -> val
|
||||
"#
|
||||
),
|
||||
42.0,
|
||||
f64
|
||||
);
|
||||
assert_evals_to!("Num.logChecked 0", RocResult::err(()), RocResult<f64, ()>);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
|
||||
fn f64_log_negative() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
Num.log -1
|
||||
"#
|
||||
),
|
||||
true,
|
||||
f64,
|
||||
|f: f64| f.is_nan()
|
||||
);
|
||||
assert_evals_to!("Num.log -1", true, f64, |f: f64| f.is_nan());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
@ -227,7 +227,7 @@ fn is_err() {
|
|||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))]
|
||||
fn roc_result_ok() {
|
||||
fn roc_result_ok_i64() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
|
@ -242,6 +242,26 @@ fn roc_result_ok() {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
|
||||
fn roc_result_ok_f64() {
|
||||
// NOTE: the dev backend does not currently use float registers when returning a more
|
||||
// complex type, but the rust side does expect it to. Hence this test fails with gen-dev
|
||||
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
result : Result F64 {}
|
||||
result = Ok 42.0
|
||||
|
||||
result
|
||||
"#
|
||||
),
|
||||
RocResult::ok(42.0),
|
||||
RocResult<f64, ()>
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))]
|
||||
fn roc_result_err() {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue