implement sqrt and log in the dev backend

This commit is contained in:
Folkert 2023-02-26 19:05:33 +01:00
parent 6a40d75353
commit 44f08f9e47
No known key found for this signature in database
GPG key ID: 1F17F6FFD112B97C
7 changed files with 168 additions and 88 deletions

View file

@ -748,6 +748,7 @@ sqrtChecked = \x ->
else
Ok (Num.sqrt x)
## Natural logarithm
log : Frac a -> Frac a
logChecked : Frac a -> Result (Frac a) [LogNeedsPositive]

View file

@ -1053,6 +1053,14 @@ impl Assembler<AArch64GeneralReg, AArch64FloatReg> for AArch64Assembler {
{
todo!("sar for AArch64")
}
fn sqrt_freg64_freg64(_buf: &mut Vec<'_, u8>, _dst: AArch64FloatReg, _src: AArch64FloatReg) {
todo!("sqrt")
}
fn sqrt_freg32_freg32(_buf: &mut Vec<'_, u8>, _dst: AArch64FloatReg, _src: AArch64FloatReg) {
todo!("sqrt")
}
}
impl AArch64Assembler {}

View file

@ -317,6 +317,9 @@ pub trait Assembler<GeneralReg: RegTrait, FloatReg: RegTrait>: Sized + Copy {
fn mov_stack32_freg64(buf: &mut Vec<'_, u8>, offset: i32, src: FloatReg);
fn mov_stack32_reg64(buf: &mut Vec<'_, u8>, offset: i32, src: GeneralReg);
fn sqrt_freg64_freg64(buf: &mut Vec<'_, u8>, dst: FloatReg, src: FloatReg);
fn sqrt_freg32_freg32(buf: &mut Vec<'_, u8>, dst: FloatReg, src: FloatReg);
fn neg_reg64_reg64(buf: &mut Vec<'_, u8>, dst: GeneralReg, src: GeneralReg);
fn mul_freg32_freg32_freg32(
buf: &mut Vec<'_, u8>,
@ -2571,6 +2574,18 @@ impl<
}
}
}
fn build_num_sqrt(&mut self, dst: Symbol, src: Symbol, float_width: FloatWidth) {
let buf = &mut self.buf;
let dst_reg = self.storage_manager.claim_float_reg(buf, &dst);
let src_reg = self.storage_manager.load_to_float_reg(buf, &src);
match float_width {
FloatWidth::F32 => ASM::sqrt_freg32_freg32(buf, dst_reg, src_reg),
FloatWidth::F64 => ASM::sqrt_freg64_freg64(buf, dst_reg, src_reg),
}
}
}
/// This impl block is for ir related instructions that need backend specific information.

View file

@ -1745,6 +1745,14 @@ impl Assembler<X86_64GeneralReg, X86_64FloatReg> for X86_64Assembler {
{
shift_reg64_reg64_reg64(buf, storage_manager, sar_reg64_reg64, dst, src1, src2)
}
fn sqrt_freg64_freg64(buf: &mut Vec<'_, u8>, dst: X86_64FloatReg, src: X86_64FloatReg) {
sqrtsd_freg64_freg64(buf, dst, src)
}
fn sqrt_freg32_freg32(buf: &mut Vec<'_, u8>, dst: X86_64FloatReg, src: X86_64FloatReg) {
sqrtss_freg32_freg32(buf, dst, src)
}
}
fn shift_reg64_reg64_reg64<'a, 'r, ASM, CC>(
@ -2157,6 +2165,48 @@ fn cmp_freg32_freg32(buf: &mut Vec<'_, u8>, src1: X86_64FloatReg, src2: X86_64Fl
}
}
#[inline(always)]
fn sqrtsd_freg64_freg64(buf: &mut Vec<'_, u8>, dst: X86_64FloatReg, src: X86_64FloatReg) {
let dst_high = dst as u8 > 7;
let dst_mod = dst as u8 % 8;
let src_high = src as u8 > 7;
let src_mod = src as u8 % 8;
if dst_high || src_high {
buf.extend([
0xF2,
0x40 | ((dst_high as u8) << 2) | (src_high as u8),
0x0F,
0x51,
0xC0 | (dst_mod << 3) | (src_mod),
])
} else {
buf.extend([0xF2, 0x0F, 0x51, 0xC0 | (dst_mod << 3) | (src_mod)])
}
}
#[inline(always)]
fn sqrtss_freg32_freg32(buf: &mut Vec<'_, u8>, dst: X86_64FloatReg, src: X86_64FloatReg) {
let dst_high = dst as u8 > 7;
let dst_mod = dst as u8 % 8;
let src_high = src as u8 > 7;
let src_mod = src as u8 % 8;
if dst_high || src_high {
buf.extend([
0xF3,
0x40 | ((dst_high as u8) << 2) | (src_high as u8),
0x0F,
0x51,
0xC0 | (dst_mod << 3) | (src_mod),
])
} else {
buf.extend([0xF3, 0x0F, 0x51, 0xC0 | (dst_mod << 3) | (src_mod)])
}
}
/// `TEST r/m64,r64` -> AND r64 with r/m64; set SF, ZF, PF according to result.
#[allow(dead_code)]
#[inline(always)]
@ -3601,4 +3651,24 @@ mod tests {
fn test_push_reg64() {
disassembler_test!(push_reg64, |reg| format!("push {}", reg), ALL_GENERAL_REGS);
}
#[test]
fn test_sqrt_freg64_freg64() {
disassembler_test!(
sqrtsd_freg64_freg64,
|dst, src| format!("sqrtsd {dst}, {src}"),
ALL_FLOAT_REGS,
ALL_FLOAT_REGS
);
}
#[test]
fn test_sqrt_freg32_freg32() {
disassembler_test!(
sqrtss_freg32_freg32,
|dst, src| format!("sqrtss {dst}, {src}"),
ALL_FLOAT_REGS,
ALL_FLOAT_REGS
);
}
}

View file

@ -751,6 +751,30 @@ trait Backend<'a> {
);
self.build_num_gte(sym, &args[0], &args[1], &arg_layouts[0])
}
LowLevel::NumLogUnchecked => {
let float_width = match arg_layouts[0] {
Layout::F64 => FloatWidth::F64,
Layout::F32 => FloatWidth::F32,
_ => unreachable!("invalid layout for sqrt"),
};
self.build_fn_call(
sym,
bitcode::NUM_LOG[float_width].to_string(),
args,
arg_layouts,
ret_layout,
)
}
LowLevel::NumSqrtUnchecked => {
let float_width = match arg_layouts[0] {
Layout::F64 => FloatWidth::F64,
Layout::F32 => FloatWidth::F32,
_ => unreachable!("invalid layout for sqrt"),
};
self.build_num_sqrt(*sym, args[0], float_width);
}
LowLevel::NumRound => self.build_fn_call(
sym,
bitcode::NUM_ROUND_F64[IntWidth::I64].to_string(),
@ -1261,6 +1285,9 @@ trait Backend<'a> {
arg_layout: &InLayout<'a>,
);
/// build_sqrt stores the result of `sqrt(src)` into dst.
fn build_num_sqrt(&mut self, dst: Symbol, src: Symbol, float_width: FloatWidth);
/// build_list_len returns the length of a list.
fn build_list_len(&mut self, dst: &Symbol, list: &Symbol);

View file

@ -468,112 +468,51 @@ fn f32_float_alias() {
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn f64_sqrt() {
assert_evals_to!(
indoc!(
r#"
when Num.sqrtChecked 100 is
Ok val -> val
Err _ -> -1
"#
),
10.0,
f64
);
#[cfg(any(feature = "gen-llvm", feature = "gen-dev", feature = "gen-wasm"))]
fn f64_sqrt_100() {
assert_evals_to!("Num.sqrt 100", 10.0, f64);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-dev", feature = "gen-wasm"))]
fn f64_sqrt_checked_0() {
assert_evals_to!("Num.sqrt 0", 0.0, f64);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn f64_log() {
assert_evals_to!(
indoc!(
r#"
Num.log 7.38905609893
"#
),
1.999999999999912,
f64
);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn f64_log_checked_one() {
assert_evals_to!(
indoc!(
r#"
when Num.logChecked 1 is
Ok val -> val
Err _ -> -1
"#
),
0.0,
f64
);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn f64_sqrt_zero() {
assert_evals_to!(
indoc!(
r#"
when Num.sqrtChecked 0 is
Ok val -> val
Err _ -> -1
"#
),
0.0,
f64
);
fn f64_sqrt_checked_positive() {
assert_evals_to!("Num.sqrtChecked 100", RocResult::ok(10.0), RocResult<f64, ()>);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn f64_sqrt_checked_negative() {
assert_evals_to!(
indoc!(
r#"
when Num.sqrtChecked -1 is
Err _ -> 42
Ok val -> val
"#
),
42.0,
f64
);
assert_evals_to!("Num.sqrtChecked -1f64", RocResult::err(()), RocResult<f64, ()>);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-dev", feature = "gen-wasm"))]
fn f64_log() {
assert_evals_to!("Num.log 7.38905609893", 1.999999999999912, f64);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn f64_log_checked_one() {
assert_evals_to!("Num.logChecked 1", RocResult::ok(1.0), RocResult<f64, ()>);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn f64_log_checked_zero() {
assert_evals_to!(
indoc!(
r#"
when Num.logChecked 0 is
Err _ -> 42
Ok val -> val
"#
),
42.0,
f64
);
assert_evals_to!("Num.logChecked 0", RocResult::err(()), RocResult<f64, ()>);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn f64_log_negative() {
assert_evals_to!(
indoc!(
r#"
Num.log -1
"#
),
true,
f64,
|f: f64| f.is_nan()
);
assert_evals_to!("Num.log -1", true, f64, |f: f64| f.is_nan());
}
#[test]

View file

@ -227,7 +227,7 @@ fn is_err() {
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))]
fn roc_result_ok() {
fn roc_result_ok_i64() {
assert_evals_to!(
indoc!(
r#"
@ -242,6 +242,26 @@ fn roc_result_ok() {
);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn roc_result_ok_f64() {
// NOTE: the dev backend does not currently use float registers when returning a more
// complex type, but the rust side does expect it to. Hence this test fails with gen-dev
assert_evals_to!(
indoc!(
r#"
result : Result F64 {}
result = Ok 42.0
result
"#
),
RocResult::ok(42.0),
RocResult<f64, ()>
);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))]
fn roc_result_err() {