implement Num.compare

This commit is contained in:
Folkert 2020-09-08 19:40:18 +02:00
parent 4c995b12a6
commit 1b42831973
8 changed files with 225 additions and 2 deletions

View file

@ -265,6 +265,15 @@ pub fn types() -> MutMap<Symbol, (SolvedType, Region)> {
), ),
); );
// compare : Num a, Num a -> [ LT, EQ, GT ]
add_type(
Symbol::NUM_COMPARE,
SolvedType::Func(
vec![num_type(flex(TVAR1)), num_type(flex(TVAR1))],
Box::new(ordering_type()),
),
);
// toFloat : Num a -> Float // toFloat : Num a -> Float
add_type( add_type(
Symbol::NUM_TO_FLOAT, Symbol::NUM_TO_FLOAT,
@ -722,6 +731,19 @@ fn bool_type() -> SolvedType {
SolvedType::Apply(Symbol::BOOL_BOOL, Vec::new()) SolvedType::Apply(Symbol::BOOL_BOOL, Vec::new())
} }
#[inline(always)]
fn ordering_type() -> SolvedType {
// [ LT, EQ, GT ]
SolvedType::TagUnion(
vec![
(TagName::Global("GT".into()), vec![]),
(TagName::Global("EQ".into()), vec![]),
(TagName::Global("LT".into()), vec![]),
],
Box::new(SolvedType::EmptyTagUnion),
)
}
#[inline(always)] #[inline(always)]
fn str_type() -> SolvedType { fn str_type() -> SolvedType {
SolvedType::Apply(Symbol::STR_STR, Vec::new()) SolvedType::Apply(Symbol::STR_STR, Vec::new())

View file

@ -317,6 +317,12 @@ pub fn types() -> MutMap<Symbol, (SolvedType, Region)> {
// isGte or (>=) : Num a, Num a -> Bool // isGte or (>=) : Num a, Num a -> Bool
add_num_comparison(Symbol::NUM_GTE); add_num_comparison(Symbol::NUM_GTE);
// compare : Num a, Num a -> [ LT, EQ, GT ]
add_type(Symbol::NUM_COMPARE, {
let_tvars! { u, v, w, num };
unique_function(vec![num_type(u, num), num_type(v, num)], ordering_type(w))
});
// toFloat : Num a -> Float // toFloat : Num a -> Float
add_type(Symbol::NUM_TO_FLOAT, { add_type(Symbol::NUM_TO_FLOAT, {
let_tvars! { star1, star2, a }; let_tvars! { star1, star2, a };
@ -1205,3 +1211,22 @@ fn map_type(u: VarId, key: VarId, value: VarId) -> SolvedType {
], ],
) )
} }
#[inline(always)]
fn ordering_type(u: VarId) -> SolvedType {
// [ LT, EQ, GT ]
SolvedType::Apply(
Symbol::ATTR_ATTR,
vec![
flex(u),
SolvedType::TagUnion(
vec![
(TagName::Global("GT".into()), vec![]),
(TagName::Global("EQ".into()), vec![]),
(TagName::Global("LT".into()), vec![]),
],
Box::new(SolvedType::EmptyTagUnion),
),
],
)
}

View file

@ -73,6 +73,7 @@ pub fn builtin_defs(var_store: &mut VarStore) -> MutMap<Symbol, Def> {
Symbol::NUM_GTE => num_gte, Symbol::NUM_GTE => num_gte,
Symbol::NUM_LT => num_lt, Symbol::NUM_LT => num_lt,
Symbol::NUM_LTE => num_lte, Symbol::NUM_LTE => num_lte,
Symbol::NUM_COMPARE => num_compare,
Symbol::NUM_SIN => num_sin, Symbol::NUM_SIN => num_sin,
Symbol::NUM_COS => num_cos, Symbol::NUM_COS => num_cos,
Symbol::NUM_TAN => num_tan, Symbol::NUM_TAN => num_tan,
@ -262,6 +263,11 @@ fn num_lte(symbol: Symbol, var_store: &mut VarStore) -> Def {
num_bool_binop(symbol, var_store, LowLevel::NumLte) num_bool_binop(symbol, var_store, LowLevel::NumLte)
} }
/// Num.compare : Num a, Num a -> [ LT, EQ, GT ]
fn num_compare(symbol: Symbol, var_store: &mut VarStore) -> Def {
num_bool_binop(symbol, var_store, LowLevel::NumCompare)
}
/// Num.sin : Float -> Float /// Num.sin : Float -> Float
fn num_sin(symbol: Symbol, var_store: &mut VarStore) -> Def { fn num_sin(symbol: Symbol, var_store: &mut VarStore) -> Def {
let float_var = var_store.fresh(); let float_var = var_store.fresh();

View file

@ -185,7 +185,7 @@ pub fn construct_optimization_passes<'a>(
} }
OptLevel::Optimize => { OptLevel::Optimize => {
// this threshold seems to do what we want // this threshold seems to do what we want
pmb.set_inliner_with_threshold(2); pmb.set_inliner_with_threshold(275);
// TODO figure out which of these actually help // TODO figure out which of these actually help
@ -1650,6 +1650,88 @@ fn run_low_level<'a, 'ctx, 'env>(
} }
} }
} }
NumCompare => {
use inkwell::FloatPredicate;
use inkwell::IntPredicate;
debug_assert_eq!(args.len(), 2);
let (lhs_arg, lhs_layout) = load_symbol_and_layout(env, scope, &args[0]);
let (rhs_arg, rhs_layout) = load_symbol_and_layout(env, scope, &args[1]);
match (lhs_layout, rhs_layout) {
(Layout::Builtin(lhs_builtin), Layout::Builtin(rhs_builtin))
if lhs_builtin == rhs_builtin =>
{
use roc_mono::layout::Builtin::*;
let tag_eq = env.context.i8_type().const_int(0 as u64, false);
let tag_gt = env.context.i8_type().const_int(1 as u64, false);
let tag_lt = env.context.i8_type().const_int(2 as u64, false);
match lhs_builtin {
Int128 | Int64 | Int32 | Int16 | Int8 => {
let are_equal = env.builder.build_int_compare(
IntPredicate::EQ,
lhs_arg.into_int_value(),
rhs_arg.into_int_value(),
"int_eq",
);
let is_less_than = env.builder.build_int_compare(
IntPredicate::SLT,
lhs_arg.into_int_value(),
rhs_arg.into_int_value(),
"int_compare",
);
let step1 =
env.builder
.build_select(is_less_than, tag_lt, tag_gt, "lt_or_gt");
env.builder.build_select(
are_equal,
tag_eq,
step1.into_int_value(),
"lt_or_gt",
)
}
Float128 | Float64 | Float32 | Float16 => {
let are_equal = env.builder.build_float_compare(
FloatPredicate::OEQ,
lhs_arg.into_float_value(),
rhs_arg.into_float_value(),
"float_eq",
);
let is_less_than = env.builder.build_float_compare(
FloatPredicate::OLT,
lhs_arg.into_float_value(),
rhs_arg.into_float_value(),
"float_compare",
);
let step1 =
env.builder
.build_select(is_less_than, tag_lt, tag_gt, "lt_or_gt");
env.builder.build_select(
are_equal,
tag_eq,
step1.into_int_value(),
"lt_or_gt",
)
}
_ => {
unreachable!("Compiler bug: tried to run numeric operation {:?} on invalid builtin layout: ({:?})", op, lhs_layout);
}
}
}
_ => {
unreachable!("Compiler bug: tried to run numeric operation {:?} on invalid layouts. The 2 layouts were: ({:?}) and ({:?})", op, lhs_layout, rhs_layout);
}
}
}
NumAdd | NumSub | NumMul | NumLt | NumLte | NumGt | NumGte | NumRemUnchecked NumAdd | NumSub | NumMul | NumLt | NumLte | NumGt | NumGte | NumRemUnchecked
| NumDivUnchecked => { | NumDivUnchecked => {
debug_assert_eq!(args.len(), 2); debug_assert_eq!(args.len(), 2);

View file

@ -581,4 +581,88 @@ mod gen_num {
fn float_to_float() { fn float_to_float() {
assert_evals_to!("Num.toFloat 0.5", 0.5, f64); assert_evals_to!("Num.toFloat 0.5", 0.5, f64);
} }
#[test]
fn int_compare() {
assert_evals_to!(
indoc!(
r#"
when Num.compare 0 1 is
LT -> 0
EQ -> 1
GT -> 2
"#
),
0,
i64
);
assert_evals_to!(
indoc!(
r#"
when Num.compare 1 1 is
LT -> 0
EQ -> 1
GT -> 2
"#
),
1,
i64
);
assert_evals_to!(
indoc!(
r#"
when Num.compare 1 0 is
LT -> 0
EQ -> 1
GT -> 2
"#
),
2,
i64
);
}
#[test]
fn float_compare() {
assert_evals_to!(
indoc!(
r#"
when Num.compare 0 3.14 is
LT -> 0
EQ -> 1
GT -> 2
"#
),
0,
i64
);
assert_evals_to!(
indoc!(
r#"
when Num.compare 3.14 3.14 is
LT -> 0
EQ -> 1
GT -> 2
"#
),
1,
i64
);
assert_evals_to!(
indoc!(
r#"
when Num.compare 3.14 0 is
LT -> 0
EQ -> 1
GT -> 2
"#
),
2,
i64
);
}
} }

View file

@ -25,6 +25,7 @@ pub enum LowLevel {
NumGte, NumGte,
NumLt, NumLt,
NumLte, NumLte,
NumCompare,
NumDivUnchecked, NumDivUnchecked,
NumRemUnchecked, NumRemUnchecked,
NumAbs, NumAbs,

View file

@ -639,6 +639,7 @@ define_builtins! {
34 NUM_MOD_FLOAT: "modFloat" 34 NUM_MOD_FLOAT: "modFloat"
35 NUM_SQRT: "sqrt" 35 NUM_SQRT: "sqrt"
36 NUM_ROUND: "round" 36 NUM_ROUND: "round"
37 NUM_COMPARE: "compare"
} }
2 BOOL: "Bool" => { 2 BOOL: "Bool" => {
0 BOOL_BOOL: "Bool" imported // the Bool.Bool type alias 0 BOOL_BOOL: "Bool" imported // the Bool.Bool type alias

View file

@ -522,7 +522,9 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] {
ListWalkRight => arena.alloc_slice_copy(&[borrowed, irrelevant, owned]), ListWalkRight => arena.alloc_slice_copy(&[borrowed, irrelevant, owned]),
Eq | NotEq | And | Or | NumAdd | NumSub | NumMul | NumGt | NumGte | NumLt | NumLte Eq | NotEq | And | Or | NumAdd | NumSub | NumMul | NumGt | NumGte | NumLt | NumLte
| NumDivUnchecked | NumRemUnchecked => arena.alloc_slice_copy(&[irrelevant, irrelevant]), | NumCompare | NumDivUnchecked | NumRemUnchecked => {
arena.alloc_slice_copy(&[irrelevant, irrelevant])
}
NumAbs | NumNeg | NumSin | NumCos | NumSqrtUnchecked | NumRound | NumToFloat | Not => { NumAbs | NumNeg | NumSin | NumCos | NumSqrtUnchecked | NumRound | NumToFloat | Not => {
arena.alloc_slice_copy(&[irrelevant]) arena.alloc_slice_copy(&[irrelevant])