From 1b42831973d9c47f54c8e9da903dfac6672e46c0 Mon Sep 17 00:00:00 2001 From: Folkert Date: Tue, 8 Sep 2020 19:40:18 +0200 Subject: [PATCH] implement Num.compare --- compiler/builtins/src/std.rs | 22 +++++++++ compiler/builtins/src/unique.rs | 25 ++++++++++ compiler/can/src/builtins.rs | 6 +++ compiler/gen/src/llvm/build.rs | 84 +++++++++++++++++++++++++++++++- compiler/gen/tests/gen_num.rs | 84 ++++++++++++++++++++++++++++++++ compiler/module/src/low_level.rs | 1 + compiler/module/src/symbol.rs | 1 + compiler/mono/src/borrow.rs | 4 +- 8 files changed, 225 insertions(+), 2 deletions(-) diff --git a/compiler/builtins/src/std.rs b/compiler/builtins/src/std.rs index 65a89684d8..18b0fc7682 100644 --- a/compiler/builtins/src/std.rs +++ b/compiler/builtins/src/std.rs @@ -265,6 +265,15 @@ pub fn types() -> MutMap { ), ); + // compare : Num a, Num a -> [ LT, EQ, GT ] + add_type( + Symbol::NUM_COMPARE, + SolvedType::Func( + vec![num_type(flex(TVAR1)), num_type(flex(TVAR1))], + Box::new(ordering_type()), + ), + ); + // toFloat : Num a -> Float add_type( Symbol::NUM_TO_FLOAT, @@ -722,6 +731,19 @@ fn bool_type() -> SolvedType { SolvedType::Apply(Symbol::BOOL_BOOL, Vec::new()) } +#[inline(always)] +fn ordering_type() -> SolvedType { + // [ LT, EQ, GT ] + SolvedType::TagUnion( + vec![ + (TagName::Global("GT".into()), vec![]), + (TagName::Global("EQ".into()), vec![]), + (TagName::Global("LT".into()), vec![]), + ], + Box::new(SolvedType::EmptyTagUnion), + ) +} + #[inline(always)] fn str_type() -> SolvedType { SolvedType::Apply(Symbol::STR_STR, Vec::new()) diff --git a/compiler/builtins/src/unique.rs b/compiler/builtins/src/unique.rs index 949d36f0e6..3712a2e1df 100644 --- a/compiler/builtins/src/unique.rs +++ b/compiler/builtins/src/unique.rs @@ -317,6 +317,12 @@ pub fn types() -> MutMap { // isGte or (>=) : Num a, Num a -> Bool add_num_comparison(Symbol::NUM_GTE); + // compare : Num a, Num a -> [ LT, EQ, GT ] + add_type(Symbol::NUM_COMPARE, { + let_tvars! { u, v, w, num }; + unique_function(vec![num_type(u, num), num_type(v, num)], ordering_type(w)) + }); + // toFloat : Num a -> Float add_type(Symbol::NUM_TO_FLOAT, { let_tvars! { star1, star2, a }; @@ -1205,3 +1211,22 @@ fn map_type(u: VarId, key: VarId, value: VarId) -> SolvedType { ], ) } + +#[inline(always)] +fn ordering_type(u: VarId) -> SolvedType { + // [ LT, EQ, GT ] + SolvedType::Apply( + Symbol::ATTR_ATTR, + vec![ + flex(u), + SolvedType::TagUnion( + vec![ + (TagName::Global("GT".into()), vec![]), + (TagName::Global("EQ".into()), vec![]), + (TagName::Global("LT".into()), vec![]), + ], + Box::new(SolvedType::EmptyTagUnion), + ), + ], + ) +} diff --git a/compiler/can/src/builtins.rs b/compiler/can/src/builtins.rs index e2d7b7cea4..ab00f998e6 100644 --- a/compiler/can/src/builtins.rs +++ b/compiler/can/src/builtins.rs @@ -73,6 +73,7 @@ pub fn builtin_defs(var_store: &mut VarStore) -> MutMap { Symbol::NUM_GTE => num_gte, Symbol::NUM_LT => num_lt, Symbol::NUM_LTE => num_lte, + Symbol::NUM_COMPARE => num_compare, Symbol::NUM_SIN => num_sin, Symbol::NUM_COS => num_cos, Symbol::NUM_TAN => num_tan, @@ -262,6 +263,11 @@ fn num_lte(symbol: Symbol, var_store: &mut VarStore) -> Def { num_bool_binop(symbol, var_store, LowLevel::NumLte) } +/// Num.compare : Num a, Num a -> [ LT, EQ, GT ] +fn num_compare(symbol: Symbol, var_store: &mut VarStore) -> Def { + num_bool_binop(symbol, var_store, LowLevel::NumCompare) +} + /// Num.sin : Float -> Float fn num_sin(symbol: Symbol, var_store: &mut VarStore) -> Def { let float_var = var_store.fresh(); diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index 8bce77df74..cf76f50137 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -185,7 +185,7 @@ pub fn construct_optimization_passes<'a>( } OptLevel::Optimize => { // this threshold seems to do what we want - pmb.set_inliner_with_threshold(2); + pmb.set_inliner_with_threshold(275); // TODO figure out which of these actually help @@ -1650,6 +1650,88 @@ fn run_low_level<'a, 'ctx, 'env>( } } } + NumCompare => { + use inkwell::FloatPredicate; + use inkwell::IntPredicate; + + debug_assert_eq!(args.len(), 2); + + let (lhs_arg, lhs_layout) = load_symbol_and_layout(env, scope, &args[0]); + let (rhs_arg, rhs_layout) = load_symbol_and_layout(env, scope, &args[1]); + + match (lhs_layout, rhs_layout) { + (Layout::Builtin(lhs_builtin), Layout::Builtin(rhs_builtin)) + if lhs_builtin == rhs_builtin => + { + use roc_mono::layout::Builtin::*; + + let tag_eq = env.context.i8_type().const_int(0 as u64, false); + let tag_gt = env.context.i8_type().const_int(1 as u64, false); + let tag_lt = env.context.i8_type().const_int(2 as u64, false); + + match lhs_builtin { + Int128 | Int64 | Int32 | Int16 | Int8 => { + let are_equal = env.builder.build_int_compare( + IntPredicate::EQ, + lhs_arg.into_int_value(), + rhs_arg.into_int_value(), + "int_eq", + ); + let is_less_than = env.builder.build_int_compare( + IntPredicate::SLT, + lhs_arg.into_int_value(), + rhs_arg.into_int_value(), + "int_compare", + ); + + let step1 = + env.builder + .build_select(is_less_than, tag_lt, tag_gt, "lt_or_gt"); + + env.builder.build_select( + are_equal, + tag_eq, + step1.into_int_value(), + "lt_or_gt", + ) + } + Float128 | Float64 | Float32 | Float16 => { + let are_equal = env.builder.build_float_compare( + FloatPredicate::OEQ, + lhs_arg.into_float_value(), + rhs_arg.into_float_value(), + "float_eq", + ); + let is_less_than = env.builder.build_float_compare( + FloatPredicate::OLT, + lhs_arg.into_float_value(), + rhs_arg.into_float_value(), + "float_compare", + ); + + let step1 = + env.builder + .build_select(is_less_than, tag_lt, tag_gt, "lt_or_gt"); + + env.builder.build_select( + are_equal, + tag_eq, + step1.into_int_value(), + "lt_or_gt", + ) + } + + _ => { + unreachable!("Compiler bug: tried to run numeric operation {:?} on invalid builtin layout: ({:?})", op, lhs_layout); + } + } + } + _ => { + unreachable!("Compiler bug: tried to run numeric operation {:?} on invalid layouts. The 2 layouts were: ({:?}) and ({:?})", op, lhs_layout, rhs_layout); + } + } + } + NumAdd | NumSub | NumMul | NumLt | NumLte | NumGt | NumGte | NumRemUnchecked | NumDivUnchecked => { debug_assert_eq!(args.len(), 2); diff --git a/compiler/gen/tests/gen_num.rs b/compiler/gen/tests/gen_num.rs index e329cc7a92..812314a806 100644 --- a/compiler/gen/tests/gen_num.rs +++ b/compiler/gen/tests/gen_num.rs @@ -581,4 +581,88 @@ mod gen_num { fn float_to_float() { assert_evals_to!("Num.toFloat 0.5", 0.5, f64); } + + #[test] + fn int_compare() { + assert_evals_to!( + indoc!( + r#" + when Num.compare 0 1 is + LT -> 0 + EQ -> 1 + GT -> 2 + "# + ), + 0, + i64 + ); + + assert_evals_to!( + indoc!( + r#" + when Num.compare 1 1 is + LT -> 0 + EQ -> 1 + GT -> 2 + "# + ), + 1, + i64 + ); + + assert_evals_to!( + indoc!( + r#" + when Num.compare 1 0 is + LT -> 0 + EQ -> 1 + GT -> 2 + "# + ), + 2, + i64 + ); + } + + #[test] + fn float_compare() { + assert_evals_to!( + indoc!( + r#" + when Num.compare 0 3.14 is + LT -> 0 + EQ -> 1 + GT -> 2 + "# + ), + 0, + i64 + ); + + assert_evals_to!( + indoc!( + r#" + when Num.compare 3.14 3.14 is + LT -> 0 + EQ -> 1 + GT -> 2 + "# + ), + 1, + i64 + ); + + assert_evals_to!( + indoc!( + r#" + when Num.compare 3.14 0 is + LT -> 0 + EQ -> 1 + GT -> 2 + "# + ), + 2, + i64 + ); + } } diff --git a/compiler/module/src/low_level.rs b/compiler/module/src/low_level.rs index 632796f4fd..1ca4f18770 100644 --- a/compiler/module/src/low_level.rs +++ b/compiler/module/src/low_level.rs @@ -25,6 +25,7 @@ pub enum LowLevel { NumGte, NumLt, NumLte, + NumCompare, NumDivUnchecked, NumRemUnchecked, NumAbs, diff --git a/compiler/module/src/symbol.rs b/compiler/module/src/symbol.rs index ecf9511c32..b3711ef4cc 100644 --- a/compiler/module/src/symbol.rs +++ b/compiler/module/src/symbol.rs @@ -639,6 +639,7 @@ define_builtins! { 34 NUM_MOD_FLOAT: "modFloat" 35 NUM_SQRT: "sqrt" 36 NUM_ROUND: "round" + 37 NUM_COMPARE: "compare" } 2 BOOL: "Bool" => { 0 BOOL_BOOL: "Bool" imported // the Bool.Bool type alias diff --git a/compiler/mono/src/borrow.rs b/compiler/mono/src/borrow.rs index 596fd14a03..f46fe365c4 100644 --- a/compiler/mono/src/borrow.rs +++ b/compiler/mono/src/borrow.rs @@ -522,7 +522,9 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] { ListWalkRight => arena.alloc_slice_copy(&[borrowed, irrelevant, owned]), Eq | NotEq | And | Or | NumAdd | NumSub | NumMul | NumGt | NumGte | NumLt | NumLte - | NumDivUnchecked | NumRemUnchecked => arena.alloc_slice_copy(&[irrelevant, irrelevant]), + | NumCompare | NumDivUnchecked | NumRemUnchecked => { + arena.alloc_slice_copy(&[irrelevant, irrelevant]) + } NumAbs | NumNeg | NumSin | NumCos | NumSqrtUnchecked | NumRound | NumToFloat | Not => { arena.alloc_slice_copy(&[irrelevant])