float comparisons

This commit is contained in:
Folkert 2023-02-26 17:37:00 +01:00
parent f1fa014524
commit 6a40d75353
No known key found for this signature in database
GPG key ID: 1F17F6FFD112B97C
4 changed files with 219 additions and 17 deletions

View file

@ -2,10 +2,13 @@ use crate::generic64::{storage::StorageManager, Assembler, CallConv, RegTrait};
use crate::Relocation;
use bumpalo::collections::Vec;
use packed_struct::prelude::*;
use roc_builtins::bitcode::FloatWidth;
use roc_error_macros::internal_error;
use roc_module::symbol::Symbol;
use roc_mono::layout::{InLayout, STLayoutInterner};
use super::CompareOperation;
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Debug)]
#[allow(dead_code)]
pub enum AArch64GeneralReg {
@ -888,6 +891,18 @@ impl Assembler<AArch64GeneralReg, AArch64FloatReg> for AArch64Assembler {
todo!("registers unsigned less than for AArch64");
}
/// Float comparison (`src1 <operation> src2`, result in `dst`) for AArch64.
///
/// Not implemented yet: every parameter is unused (hence the `_` prefixes) and
/// calling this panics through `todo!`.
#[inline(always)]
fn cmp_freg_freg_reg64(
_buf: &mut Vec<'_, u8>,
_dst: AArch64GeneralReg,
_src1: AArch64FloatReg,
_src2: AArch64FloatReg,
_width: FloatWidth,
_operation: CompareOperation,
) {
todo!("registers float comparison for AArch64");
}
#[inline(always)]
fn igt_reg64_reg64_reg64(
_buf: &mut Vec<'_, u8>,

View file

@ -113,6 +113,13 @@ pub trait CallConv<GeneralReg: RegTrait, FloatReg: RegTrait, ASM: Assembler<Gene
);
}
/// The ordering comparison that `cmp_freg_freg_reg64` is asked to emit.
///
/// Every variant reads as `src1 <op> src2`.
///
/// The enum is a plain tag, so it derives the usual value-type traits: it can be
/// copied freely, compared, hashed and printed in diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompareOperation {
    /// `src1 < src2`
    LessThan,
    /// `src1 <= src2`
    LessThanOrEqual,
    /// `src1 > src2`
    GreaterThan,
    /// `src1 >= src2`
    GreaterThanOrEqual,
}
/// Assembler contains calls to the backend assembly generator.
/// These calls do not necessarily map directly to a single assembly instruction.
/// They are higher level in cases where an instruction would not be common and shared between multiple architectures.
@ -406,6 +413,15 @@ pub trait Assembler<GeneralReg: RegTrait, FloatReg: RegTrait>: Sized + Copy {
src2: GeneralReg,
);
/// Emits a comparison of two float registers.
///
/// The comparison is `src1 <operation> src2`, performed at the given float
/// `width`; the outcome is written to the general-purpose register `dst`.
/// NOTE(review): the exact representation of the result in `dst` (e.g. whether
/// the upper bits are cleared) is backend-specific — confirm per implementation.
fn cmp_freg_freg_reg64(
buf: &mut Vec<'_, u8>,
dst: GeneralReg,
src1: FloatReg,
src2: FloatReg,
width: FloatWidth,
operation: CompareOperation,
);
fn igt_reg64_reg64_reg64(
buf: &mut Vec<'_, u8>,
dst: GeneralReg,
@ -1273,6 +1289,20 @@ impl<
.load_to_general_reg(&mut self.buf, src2);
ASM::ult_reg64_reg64_reg64(&mut self.buf, dst_reg, src1_reg, src2_reg);
}
Layout::Builtin(Builtin::Float(width)) => {
let dst_reg = self.storage_manager.claim_general_reg(&mut self.buf, dst);
let src1_reg = self.storage_manager.load_to_float_reg(&mut self.buf, src1);
let src2_reg = self.storage_manager.load_to_float_reg(&mut self.buf, src2);
ASM::cmp_freg_freg_reg64(
&mut self.buf,
dst_reg,
src1_reg,
src2_reg,
width,
CompareOperation::LessThan,
);
}
x => todo!("NumLt: layout, {:?}", x),
}
}
@ -1305,6 +1335,20 @@ impl<
.load_to_general_reg(&mut self.buf, src2);
ASM::ugt_reg64_reg64_reg64(&mut self.buf, dst_reg, src1_reg, src2_reg);
}
Layout::Builtin(Builtin::Float(width)) => {
let dst_reg = self.storage_manager.claim_general_reg(&mut self.buf, dst);
let src1_reg = self.storage_manager.load_to_float_reg(&mut self.buf, src1);
let src2_reg = self.storage_manager.load_to_float_reg(&mut self.buf, src2);
ASM::cmp_freg_freg_reg64(
&mut self.buf,
dst_reg,
src1_reg,
src2_reg,
width,
CompareOperation::GreaterThan,
);
}
x => todo!("NumGt: layout, {:?}", x),
}
}
@ -1385,6 +1429,26 @@ impl<
.load_to_general_reg(&mut self.buf, src2);
ASM::lte_reg64_reg64_reg64(&mut self.buf, dst_reg, src1_reg, src2_reg);
}
Layout::F64 | Layout::F32 => {
let width = if *arg_layout == Layout::F64 {
FloatWidth::F64
} else {
FloatWidth::F32
};
let dst_reg = self.storage_manager.claim_general_reg(&mut self.buf, dst);
let src1_reg = self.storage_manager.load_to_float_reg(&mut self.buf, src1);
let src2_reg = self.storage_manager.load_to_float_reg(&mut self.buf, src2);
ASM::cmp_freg_freg_reg64(
&mut self.buf,
dst_reg,
src1_reg,
src2_reg,
width,
CompareOperation::LessThanOrEqual,
);
}
x => todo!("NumLte: layout, {:?}", x),
}
}
@ -1407,6 +1471,26 @@ impl<
.load_to_general_reg(&mut self.buf, src2);
ASM::gte_reg64_reg64_reg64(&mut self.buf, dst_reg, src1_reg, src2_reg);
}
Layout::F64 | Layout::F32 => {
let width = if *arg_layout == Layout::F64 {
FloatWidth::F64
} else {
FloatWidth::F32
};
let dst_reg = self.storage_manager.claim_general_reg(&mut self.buf, dst);
let src1_reg = self.storage_manager.load_to_float_reg(&mut self.buf, src1);
let src2_reg = self.storage_manager.load_to_float_reg(&mut self.buf, src2);
ASM::cmp_freg_freg_reg64(
&mut self.buf,
dst_reg,
src1_reg,
src2_reg,
width,
CompareOperation::GreaterThanOrEqual,
);
}
x => todo!("NumGte: layout, {:?}", x),
}
}
@ -2147,12 +2231,14 @@ impl<
let val = *x;
ASM::mov_reg64_imm64(&mut self.buf, reg, i128::from_ne_bytes(val) as i64);
}
(Literal::Int(x), Layout::Builtin(Builtin::Int(IntWidth::I128 | IntWidth::U128))) => {
(
Literal::Int(bytes),
Layout::Builtin(Builtin::Int(IntWidth::I128 | IntWidth::U128)),
) => {
self.storage_manager.with_tmp_general_reg(
&mut self.buf,
|storage_manager, buf, reg| {
let base_offset = storage_manager.claim_stack_area(sym, 16);
let bytes = *x;
let mut num_bytes = [0; 8];
num_bytes.copy_from_slice(&bytes[..8]);
@ -2187,6 +2273,25 @@ impl<
let val = *x as f32;
ASM::mov_freg32_imm32(&mut self.buf, &mut self.relocs, reg, val);
}
(Literal::Decimal(bytes), Layout::Builtin(Builtin::Decimal)) => {
self.storage_manager.with_tmp_general_reg(
&mut self.buf,
|storage_manager, buf, reg| {
let base_offset = storage_manager.claim_stack_area(sym, 16);
let mut num_bytes = [0; 8];
num_bytes.copy_from_slice(&bytes[..8]);
let num = i64::from_ne_bytes(num_bytes);
ASM::mov_reg64_imm64(buf, reg, num);
ASM::mov_base32_reg64(buf, base_offset, reg);
num_bytes.copy_from_slice(&bytes[8..16]);
let num = i64::from_ne_bytes(num_bytes);
ASM::mov_reg64_imm64(buf, reg, num);
ASM::mov_base32_reg64(buf, base_offset + 8, reg);
},
);
}
(Literal::Str(x), Layout::Builtin(Builtin::Str)) => {
if x.len() < 24 {
// Load small string.

View file

@ -4,10 +4,13 @@ use crate::{
single_register_layouts, Relocation,
};
use bumpalo::collections::Vec;
use roc_builtins::bitcode::FloatWidth;
use roc_error_macros::internal_error;
use roc_module::symbol::Symbol;
use roc_mono::layout::{InLayout, Layout, LayoutInterner, STLayoutInterner};
use super::CompareOperation;
// Not sure exactly how I want to represent registers.
// If we want max speed, we would likely make them structs that impl the same trait to avoid ifs.
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Debug)]
@ -1592,6 +1595,33 @@ impl Assembler<X86_64GeneralReg, X86_64FloatReg> for X86_64Assembler {
setb_reg64(buf, dst);
}
#[inline(always)]
fn cmp_freg_freg_reg64(
buf: &mut Vec<'_, u8>,
dst: X86_64GeneralReg,
src1: X86_64FloatReg,
src2: X86_64FloatReg,
width: FloatWidth,
operation: CompareOperation,
) {
use CompareOperation::*;
let (arg1, arg2) = match operation {
LessThan | LessThanOrEqual => (src1, src2),
GreaterThan | GreaterThanOrEqual => (src2, src1),
};
match width {
FloatWidth::F32 => cmp_freg32_freg32(buf, arg2, arg1),
FloatWidth::F64 => cmp_freg64_freg64(buf, arg2, arg1),
}
match operation {
LessThan | GreaterThan => seta_reg64(buf, dst),
LessThanOrEqual | GreaterThanOrEqual => setae_reg64(buf, dst),
};
}
#[inline(always)]
fn igt_reg64_reg64_reg64(
buf: &mut Vec<'_, u8>,
@ -2085,6 +2115,48 @@ fn cmp_reg64_reg64(buf: &mut Vec<'_, u8>, dst: X86_64GeneralReg, src: X86_64Gene
binop_reg64_reg64(0x39, buf, dst, src);
}
/// `UCOMISD xmm1, xmm2` -> compare the f64 in `src1` with the f64 in `src2` and set the flags.
///
/// Encoded as `66 [REX] 0F 2E /r`, with `src1` in ModRM.reg and `src2` in ModRM.rm.
/// The REX byte is only emitted when one of the registers is numbered above 7.
#[inline(always)]
fn cmp_freg64_freg64(buf: &mut Vec<'_, u8>, src1: X86_64FloatReg, src2: X86_64FloatReg) {
    let reg = src1 as u8;
    let rm = src2 as u8;

    // Registers 8..=15 need their top bit carried in the REX prefix.
    let rex_r = (reg > 7) as u8;
    let rex_b = (rm > 7) as u8;

    // mod = 0b11 (register direct), reg = src1, rm = src2.
    let modrm = 0xC0 | ((reg % 8) << 3) | (rm % 8);

    if (rex_r | rex_b) != 0 {
        buf.extend([0x66, 0x40 | (rex_r << 2) | rex_b, 0x0F, 0x2E, modrm]);
    } else {
        buf.extend([0x66, 0x0F, 0x2E, modrm]);
    }
}
/// `UCOMISS xmm1, xmm2` -> compare the f32 in `src1` with the f32 in `src2` and set the flags.
///
/// Encoded as `[REX] 0F 2E /r`, with `src1` in ModRM.reg and `src2` in ModRM.rm.
/// Unlike the f64 form (`UCOMISD`, which carries a mandatory `66` prefix), the
/// single-precision form takes no legacy prefix at all. A `0x65` byte in front
/// of it would be a GS segment-override prefix, which has no business on a
/// register-to-register compare, so none is emitted.
#[inline(always)]
fn cmp_freg32_freg32(buf: &mut Vec<'_, u8>, src1: X86_64FloatReg, src2: X86_64FloatReg) {
    let src1_high = src1 as u8 > 7;
    let src1_mod = src1 as u8 % 8;

    let src2_high = src2 as u8 > 7;
    let src2_mod = src2 as u8 % 8;

    // mod = 0b11 (register direct), reg = src1, rm = src2.
    let modrm = 0xC0 | (src1_mod << 3) | src2_mod;

    if src1_high || src2_high {
        buf.extend([
            // REX.R extends ModRM.reg (src1); REX.B extends ModRM.rm (src2).
            0x40 | ((src1_high as u8) << 2) | (src2_high as u8),
            0x0F,
            0x2E,
            modrm,
        ])
    } else {
        buf.extend([0x0F, 0x2E, modrm])
    }
}
/// `TEST r/m64,r64` -> AND r64 with r/m64; set SF, ZF, PF according to result.
#[allow(dead_code)]
#[inline(always)]
@ -2757,6 +2829,12 @@ fn seta_reg64(buf: &mut Vec<'_, u8>, reg: X86_64GeneralReg) {
set_reg64_help(0x97, buf, reg);
}
/// `SETAE r/m64` -> Set byte if above or equal (CF=0).
///
/// Delegates to `set_reg64_help` with condition byte `0x93`. Used after a float
/// compare to materialize the inclusive (`<=` / `>=`) comparisons.
#[inline(always)]
fn setae_reg64(buf: &mut Vec<'_, u8>, reg: X86_64GeneralReg) {
set_reg64_help(0x93, buf, reg);
}
/// `SETLE r/m64` -> Set byte if less or equal (ZF=1 or SF≠ OF).
#[inline(always)]
fn setle_reg64(buf: &mut Vec<'_, u8>, reg: X86_64GeneralReg) {

View file

@ -115,7 +115,7 @@ fn i8_signed_int_alias() {
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
#[cfg(any(feature = "gen-llvm", feature = "gen-dev", feature = "gen-wasm"))]
fn i128_hex_int_alias() {
assert_evals_to!(
indoc!(
@ -196,7 +196,7 @@ fn i8_hex_int_alias() {
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
#[cfg(any(feature = "gen-llvm", feature = "gen-dev", feature = "gen-wasm"))]
fn u128_signed_int_alias() {
assert_evals_to!(
indoc!(
@ -277,7 +277,7 @@ fn u8_signed_int_alias() {
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
#[cfg(any(feature = "gen-llvm", feature = "gen-dev", feature = "gen-wasm"))]
fn u128_hex_int_alias() {
assert_evals_to!(
indoc!(
@ -418,7 +418,7 @@ fn character_literal_new_line() {
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
#[cfg(any(feature = "gen-llvm", feature = "gen-dev", feature = "gen-wasm"))]
fn dec_float_alias() {
assert_evals_to!(
indoc!(
@ -451,7 +451,7 @@ fn f64_float_alias() {
);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
#[cfg(any(feature = "gen-llvm", feature = "gen-dev", feature = "gen-wasm"))]
fn f32_float_alias() {
assert_evals_to!(
indoc!(
@ -890,16 +890,20 @@ fn gen_int_neq() {
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))]
fn gen_int_less_than() {
assert_evals_to!(
indoc!(
r#"
4 < 5
"#
),
true,
bool
);
fn int_less_than() {
assert_evals_to!("4 < 5", true, bool);
}
// `<` on two float literals must evaluate to `true`.
// Compiled when any of the gen-llvm, gen-wasm or gen-dev features is enabled.
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))]
fn float_less_than() {
assert_evals_to!("4.0 < 5.0", true, bool);
}
// `>` on two float literals must evaluate to `true`.
// Compiled when any of the gen-llvm, gen-wasm or gen-dev features is enabled.
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))]
fn float_greater_than() {
assert_evals_to!("5.0 > 4.0", true, bool);
}
#[test]