diff --git a/compiler/can/src/builtins.rs b/compiler/can/src/builtins.rs index 05c60dbfad..2ddc07a0a0 100644 --- a/compiler/can/src/builtins.rs +++ b/compiler/can/src/builtins.rs @@ -54,6 +54,13 @@ pub fn builtin_defs(var_store: &mut VarStore) -> MutMap { Symbol::LIST_GET => list_get, Symbol::LIST_SET => list_set, Symbol::LIST_FIRST => list_first, + Symbol::NUM_ADD => num_add, + Symbol::NUM_SUB => num_sub, + Symbol::NUM_MUL => num_mul, + Symbol::NUM_GT => num_gt, + Symbol::NUM_GTE => num_gte, + Symbol::NUM_LT => num_lt, + Symbol::NUM_LTE => num_lte, Symbol::INT_DIV => int_div, Symbol::INT_ABS => int_abs, Symbol::INT_REM => int_rem, @@ -166,6 +173,56 @@ fn bool_and(symbol: Symbol, var_store: &mut VarStore) -> Def { ) } +fn num_binop(symbol: Symbol, var_store: &mut VarStore, op: LowLevel) -> Def { + use crate::expr::Expr::*; + + let body = RunLowLevel { + op, + args: vec![ + (var_store.fresh(), Var(Symbol::ARG_1)), + (var_store.fresh(), Var(Symbol::ARG_2)), + ], + ret_var: var_store.fresh(), + }; + + defn(symbol, vec![Symbol::ARG_1, Symbol::ARG_2], var_store, body) +} + +/// Num.add : Num a, Num a -> Num a +fn num_add(symbol: Symbol, var_store: &mut VarStore) -> Def { + num_binop(symbol, var_store, LowLevel::NumAdd) +} + +/// Num.sub : Num a, Num a -> Num a +fn num_sub(symbol: Symbol, var_store: &mut VarStore) -> Def { + num_binop(symbol, var_store, LowLevel::NumSub) +} + +/// Num.mul : Num a, Num a -> Num a +fn num_mul(symbol: Symbol, var_store: &mut VarStore) -> Def { + num_binop(symbol, var_store, LowLevel::NumMul) +} + +/// Num.gt : Num a, Num a -> Num a +fn num_gt(symbol: Symbol, var_store: &mut VarStore) -> Def { + num_binop(symbol, var_store, LowLevel::NumGt) +} + +/// Num.gte : Num a, Num a -> Num a +fn num_gte(symbol: Symbol, var_store: &mut VarStore) -> Def { + num_binop(symbol, var_store, LowLevel::NumGte) +} + +/// Num.lt : Num a, Num a -> Num a +fn num_lt(symbol: Symbol, var_store: &mut VarStore) -> Def { + num_binop(symbol, var_store, LowLevel::NumLt) +} + +/// Num.lte : Num a, Num a -> Num a +fn num_lte(symbol: Symbol, var_store: &mut VarStore) -> Def { + num_binop(symbol, var_store, LowLevel::NumLte) +} + /// Float.tan : Float -> Float fn float_tan(symbol: Symbol, var_store: &mut VarStore) -> Def { use crate::expr::Expr::*; diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index 95a09d6106..16ea6a124a 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -12,9 +12,9 @@ use inkwell::module::{Linkage, Module}; use inkwell::passes::{PassManager, PassManagerBuilder}; use inkwell::types::{BasicTypeEnum, FunctionType, IntType, PointerType, StructType}; use inkwell::values::BasicValueEnum::{self, *}; -use inkwell::values::{FunctionValue, IntValue, PointerValue, StructValue}; +use inkwell::values::{FloatValue, FunctionValue, IntValue, PointerValue, StructValue}; use inkwell::AddressSpace; -use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel}; +use inkwell::{IntPredicate, OptimizationLevel}; use roc_collections::all::ImMap; use roc_module::low_level::LowLevel; use roc_module::symbol::{Interns, Symbol}; @@ -1021,158 +1021,7 @@ fn call_with_args<'a, 'ctx, 'env>( args: &[(BasicValueEnum<'ctx>, &'a Layout<'a>)], ) -> BasicValueEnum<'ctx> { match symbol { - Symbol::INT_ADD | Symbol::NUM_ADD => { - debug_assert!(args.len() == 2); - - let int_val = env.builder.build_int_add( - args[0].0.into_int_value(), - args[1].0.into_int_value(), - "add_i64", - ); - - BasicValueEnum::IntValue(int_val) - } - Symbol::FLOAT_ADD => { - debug_assert!(args.len() == 2); - - let float_val = env.builder.build_float_add( - args[0].0.into_float_value(), - args[1].0.into_float_value(), - "add_f64", - ); - - BasicValueEnum::FloatValue(float_val) - } - Symbol::INT_SUB | Symbol::NUM_SUB => { - debug_assert!(args.len() == 2); - - let int_val = env.builder.build_int_sub( - args[0].0.into_int_value(), - args[1].0.into_int_value(), - "sub_i64", - ); - - BasicValueEnum::IntValue(int_val) - } - Symbol::FLOAT_DIV => { - debug_assert!(args.len() == 2); - - let float_val = env.builder.build_float_div( - args[0].0.into_float_value(), - args[1].0.into_float_value(), - "div_f64", - ); - - BasicValueEnum::FloatValue(float_val) - } - Symbol::FLOAT_SUB => { - debug_assert!(args.len() == 2); - - let float_val = env.builder.build_float_sub( - args[0].0.into_float_value(), - args[1].0.into_float_value(), - "sub_f64", - ); - - BasicValueEnum::FloatValue(float_val) - } Symbol::FLOAT_ABS => call_intrinsic(LLVM_FABS_F64, env, args), - Symbol::INT_GTE | Symbol::NUM_GTE => { - debug_assert!(args.len() == 2); - - let bool_val = env.builder.build_int_compare( - IntPredicate::SGE, - args[0].0.into_int_value(), - args[1].0.into_int_value(), - "gte_i64", - ); - - BasicValueEnum::IntValue(bool_val) - } - Symbol::FLOAT_GTE => { - debug_assert!(args.len() == 2); - - let bool_val = env.builder.build_float_compare( - FloatPredicate::OGE, - args[0].0.into_float_value(), - args[1].0.into_float_value(), - "gte_F64", - ); - - BasicValueEnum::IntValue(bool_val) - } - Symbol::INT_GT | Symbol::NUM_GT => { - debug_assert!(args.len() == 2); - - let bool_val = env.builder.build_int_compare( - IntPredicate::SGT, - args[0].0.into_int_value(), - args[1].0.into_int_value(), - "gt_i64", - ); - - BasicValueEnum::IntValue(bool_val) - } - Symbol::FLOAT_GT => { - debug_assert!(args.len() == 2); - - let bool_val = env.builder.build_float_compare( - FloatPredicate::OGT, - args[0].0.into_float_value(), - args[1].0.into_float_value(), - "gt_f64", - ); - - BasicValueEnum::IntValue(bool_val) - } - Symbol::INT_LTE | Symbol::NUM_LTE => { - debug_assert!(args.len() == 2); - - let bool_val = env.builder.build_int_compare( - IntPredicate::SLE, - args[0].0.into_int_value(), - args[1].0.into_int_value(), - "lte_i64", - ); - - BasicValueEnum::IntValue(bool_val) - } - Symbol::FLOAT_LTE => { - debug_assert!(args.len() == 2); - - let bool_val = env.builder.build_float_compare( - FloatPredicate::OLE, - args[0].0.into_float_value(), - args[1].0.into_float_value(), - "lte_f64", - ); - - BasicValueEnum::IntValue(bool_val) - } - Symbol::INT_LT | Symbol::NUM_LT => { - debug_assert!(args.len() == 2); - - let bool_val = env.builder.build_int_compare( - IntPredicate::SLT, - args[0].0.into_int_value(), - args[1].0.into_int_value(), - "lt_i64", - ); - - BasicValueEnum::IntValue(bool_val) - } - Symbol::FLOAT_LT => { - debug_assert!(args.len() == 2); - - let bool_val = env.builder.build_float_compare( - FloatPredicate::OLT, - args[0].0.into_float_value(), - args[1].0.into_float_value(), - "lt_f64", - ); - - BasicValueEnum::IntValue(bool_val) - } Symbol::FLOAT_SIN => call_intrinsic(LLVM_SIN_F64, env, args), Symbol::FLOAT_COS => call_intrinsic(LLVM_COS_F64, env, args), Symbol::NUM_MUL => { @@ -1247,18 +1096,6 @@ fn call_with_args<'a, 'ctx, 'env>( .left() .unwrap_or_else(|| panic!("LLVM error: Invalid call for builtin {:?}", symbol)) } - Symbol::FLOAT_EQ => { - debug_assert!(args.len() == 2); - - let int_val = env.builder.build_float_compare( - FloatPredicate::OEQ, - args[0].0.into_float_value(), - args[1].0.into_float_value(), - "cmp_f64", - ); - - BasicValueEnum::IntValue(int_val) - } Symbol::FLOAT_SQRT => call_intrinsic(LLVM_SQRT_F64, env, args), Symbol::FLOAT_ROUND => call_intrinsic(LLVM_LROUND_I64_F64, env, args), Symbol::LIST_SET => list_set(parent, args, env, InPlace::Clone), @@ -1615,6 +1452,43 @@ fn run_low_level<'a, 'ctx, 'env>( BasicValueEnum::IntValue(load_list_len(env.builder, arg.into_struct_value())) } + NumAdd | NumSub | NumMul | NumLt | NumLte | NumGt | NumGte => { + debug_assert_eq!(args.len(), 2); + + let lhs_arg = build_expr(env, layout_ids, scope, parent, &args[0].0); + let lhs_layout = &args[0].1; + let rhs_arg = build_expr(env, layout_ids, scope, parent, &args[1].0); + let rhs_layout = &args[1].1; + + match (lhs_layout, rhs_layout) { + (Layout::Builtin(lhs_builtin), Layout::Builtin(rhs_builtin)) + if lhs_builtin == rhs_builtin => + { + use roc_mono::layout::Builtin::*; + + match lhs_builtin { + Int128 | Int64 | Int32 | Int16 | Int8 => build_int_binop( + env.builder, + lhs_arg.into_int_value(), + rhs_arg.into_int_value(), + op, + ), + Float64 | Float32 => build_float_binop( + env.builder, + lhs_arg.into_float_value(), + rhs_arg.into_float_value(), + op, + ), + _ => { + unreachable!("Compiler bug: tried to run numeric operation {:?} on invalid builtin layout: ({:?})", op, lhs_layout); + } + } + } + _ => { + unreachable!("Compiler bug: tried to run numeric operation {:?} on invalid layouts. The 2 layouts were: ({:?}) and ({:?})", op, lhs_layout, rhs_layout); + } + } + } Eq => { debug_assert_eq!(args.len(), 2); @@ -1710,3 +1584,49 @@ fn run_low_level<'a, 'ctx, 'env>( } } } + +fn build_int_binop<'ctx>( + bd: &Builder<'ctx>, + lhs: IntValue<'ctx>, + rhs: IntValue<'ctx>, + op: LowLevel, +) -> BasicValueEnum<'ctx> { + use inkwell::IntPredicate::*; + use roc_module::low_level::LowLevel::*; + + match op { + NumAdd => bd.build_int_add(lhs, rhs, "add_int").into(), + NumSub => bd.build_int_sub(lhs, rhs, "sub_int").into(), + NumMul => bd.build_int_mul(lhs, rhs, "mul_int").into(), + NumGt => bd.build_int_compare(SGT, lhs, rhs, "int_gt").into(), + NumGte => bd.build_int_compare(SGE, lhs, rhs, "int_gte").into(), + NumLt => bd.build_int_compare(SLT, lhs, rhs, "int_lt").into(), + NumLte => bd.build_int_compare(SLE, lhs, rhs, "int_lte").into(), + _ => { + unreachable!("Unrecognized int binary operation: {:?}", op); + } + } +} + +fn build_float_binop<'ctx>( + bd: &Builder<'ctx>, + lhs: FloatValue<'ctx>, + rhs: FloatValue<'ctx>, + op: LowLevel, +) -> BasicValueEnum<'ctx> { + use inkwell::FloatPredicate::*; + use roc_module::low_level::LowLevel::*; + + match op { + NumAdd => bd.build_float_add(lhs, rhs, "add_float").into(), + NumSub => bd.build_float_sub(lhs, rhs, "sub_float").into(), + NumMul => bd.build_float_mul(lhs, rhs, "mul_float").into(), + NumGt => bd.build_float_compare(OGT, lhs, rhs, "float_gt").into(), + NumGte => bd.build_float_compare(OGE, lhs, rhs, "float_gte").into(), + NumLt => bd.build_float_compare(OLT, lhs, rhs, "float_lt").into(), + NumLte => bd.build_float_compare(OLE, lhs, rhs, "float_lte").into(), + _ => { + unreachable!("Unrecognized int binary operation: {:?}", op); + } + } +} diff --git a/compiler/module/src/low_level.rs b/compiler/module/src/low_level.rs index 4017f38069..0119572017 100644 --- a/compiler/module/src/low_level.rs +++ b/compiler/module/src/low_level.rs @@ -6,6 +6,13 @@ pub enum LowLevel { ListLen, ListGetUnsafe, ListSetUnsafe, + NumAdd, + NumSub, + NumMul, + NumGt, + NumGte, + NumLt, + NumLte, Eq, NotEq, And, diff --git a/compiler/module/src/symbol.rs b/compiler/module/src/symbol.rs index 636636a251..c7d74c3a5b 100644 --- a/compiler/module/src/symbol.rs +++ b/compiler/module/src/symbol.rs @@ -575,6 +575,14 @@ define_builtins! { 0 ATTR: "#Attr" => { 0 UNDERSCORE: "_" // the _ used in pattern matches. This is Symbol 0. 1 ATTR_ATTR: "Attr" // the #Attr.Attr type alias, used in uniqueness types. + 2 ARG_1: "#arg1" + 3 ARG_2: "#arg2" + 4 ARG_3: "#arg3" + 5 ARG_4: "#arg4" + 6 ARG_5: "#arg5" + 7 ARG_6: "#arg6" + 8 ARG_7: "#arg7" + 9 ARG_8: "#arg8" } 1 NUM: "Num" => { 0 NUM_NUM: "Num" imported // the Num.Num type alias diff --git a/compiler/mono/src/expr.rs b/compiler/mono/src/expr.rs index 19bc2a5c47..b8956fd16f 100644 --- a/compiler/mono/src/expr.rs +++ b/compiler/mono/src/expr.rs @@ -358,34 +358,6 @@ fn num_argument_to_int_or_float(subs: &Subs, var: Variable) -> IntOrFloat { } } -/// Given a `Num a`, determines whether it's an int or a float -fn num_to_int_or_float(subs: &Subs, var: Variable) -> IntOrFloat { - match subs.get_without_compacting(var).content { - Content::Alias(Symbol::NUM_NUM, args, _) => { - debug_assert!(args.len() == 1); - - num_argument_to_int_or_float(subs, args[0].1) - } - - Content::Alias(Symbol::INT_INT, _, _) => IntOrFloat::IntType, - Content::Alias(Symbol::FLOAT_FLOAT, _, _) => IntOrFloat::FloatType, - - Content::Structure(FlatType::Apply(Symbol::ATTR_ATTR, attr_args)) => { - debug_assert!(attr_args.len() == 2); - - // Recurse on the second argument - num_to_int_or_float(subs, attr_args[1]) - } - - other => { - panic!( - "Input variable is not a Num, but {:?} is a {:?}", - var, other - ); - } - } -} - /// turn record/tag patterns into a when expression, e.g. /// /// foo = \{ x } -> body @@ -586,8 +558,7 @@ fn from_can<'a>( Expr::Load(proc_name) => { // Some functions can potentially mutate in-place. // If we have one of those, switch to the in-place version if appropriate. - match specialize_builtin_functions(env, proc_name, loc_args.as_slice(), ret_var) - { + match proc_name { Symbol::LIST_SET => { let subs = &env.subs; // The first arg is the one with the List in it. @@ -723,7 +694,7 @@ fn from_can<'a>( }; expr = Expr::Store( - bumpalo::vec![in arena; (branch_symbol, Layout::Builtin(Builtin::Int8), cond)] + bumpalo::vec![in arena; (branch_symbol, Layout::Builtin(Builtin::Int1), cond)] .into_bump_slice(), env.arena.alloc(cond_expr), ); @@ -1839,52 +1810,3 @@ fn from_can_record_destruct<'a>( }, } } - -fn specialize_builtin_functions<'a>( - env: &mut Env<'a, '_>, - symbol: Symbol, - loc_args: &[(Variable, Located)], - ret_var: Variable, -) -> Symbol { - use IntOrFloat::*; - - if !symbol.is_builtin() { - // return unchanged - symbol - } else { - if true { - todo!( - "replace specialize_builtin_functions({:?}) with a LowLevel op", - symbol - ); - } - - match symbol { - Symbol::NUM_ADD => match num_to_int_or_float(env.subs, ret_var) { - FloatType => Symbol::FLOAT_ADD, - IntType => Symbol::INT_ADD, - }, - Symbol::NUM_SUB => match num_to_int_or_float(env.subs, ret_var) { - FloatType => Symbol::FLOAT_SUB, - IntType => Symbol::INT_SUB, - }, - Symbol::NUM_LTE => match num_to_int_or_float(env.subs, loc_args[0].0) { - FloatType => Symbol::FLOAT_LTE, - IntType => Symbol::INT_LTE, - }, - Symbol::NUM_LT => match num_to_int_or_float(env.subs, loc_args[0].0) { - FloatType => Symbol::FLOAT_LT, - IntType => Symbol::INT_LT, - }, - Symbol::NUM_GTE => match num_to_int_or_float(env.subs, loc_args[0].0) { - FloatType => Symbol::FLOAT_GTE, - IntType => Symbol::INT_GTE, - }, - Symbol::NUM_GT => match num_to_int_or_float(env.subs, loc_args[0].0) { - FloatType => Symbol::FLOAT_GT, - IntType => Symbol::INT_GT, - }, - _ => symbol, - } - } -} diff --git a/compiler/mono/tests/test_mono.rs b/compiler/mono/tests/test_mono.rs index 18298ecccc..8d2876c780 100644 --- a/compiler/mono/tests/test_mono.rs +++ b/compiler/mono/tests/test_mono.rs @@ -85,20 +85,6 @@ mod test_mono { compiles_to("0.5", Float(0.5)); } - #[test] - fn apply_identity() { - compiles_to( - indoc!( - r#" - identity = \a -> a - - identity 5 - "# - ), - Int(5), - ); - } - #[test] fn float_addition() { compiles_to(