Fix Num.sqrt, div, mod, and rem

This commit is contained in:
Richard Feldman 2020-06-25 21:47:02 -04:00
parent ee52d52047
commit 44477f98e9
6 changed files with 118 additions and 42 deletions

View file

@ -309,29 +309,35 @@ pub fn types() -> MutMap<Symbol, (SolvedType, Region)> {
// lowest : Int // lowest : Int
add_type(Symbol::NUM_MIN_INT, int_type()); add_type(Symbol::NUM_MIN_INT, int_type());
// div : Int, Int -> Int // div : Int, Int -> Result Int [ DivByZero ]*
add_type(
Symbol::NUM_DIV_INT,
SolvedType::Func(vec![int_type(), int_type()], Box::new(int_type())),
);
// rem : Int, Int -> Int
add_type(
Symbol::NUM_REM,
SolvedType::Func(vec![int_type(), int_type()], Box::new(int_type())),
);
// mod : Int, Int -> Result Int [ DivByZero ]*
let div_by_zero = SolvedType::TagUnion( let div_by_zero = SolvedType::TagUnion(
vec![(TagName::Global("DivByZero".into()), vec![])], vec![(TagName::Global("DivByZero".into()), vec![])],
Box::new(SolvedType::Wildcard), Box::new(SolvedType::Wildcard),
); );
add_type(
Symbol::NUM_DIV_INT,
SolvedType::Func(
vec![int_type(), int_type()],
Box::new(result_type(int_type(), div_by_zero.clone())),
),
);
// rem : Int, Int -> Result Int [ DivByZero ]*
add_type(
Symbol::NUM_REM,
SolvedType::Func(
vec![int_type(), int_type()],
Box::new(result_type(int_type(), div_by_zero.clone())),
),
);
// mod : Int, Int -> Result Int [ DivByZero ]*
add_type( add_type(
Symbol::NUM_MOD_INT, Symbol::NUM_MOD_INT,
SolvedType::Func( SolvedType::Func(
vec![int_type(), int_type()], vec![int_type(), int_type()],
Box::new(result_type(flex(TVAR1), div_by_zero)), Box::new(result_type(int_type(), div_by_zero.clone())),
), ),
); );
@ -340,19 +346,33 @@ pub fn types() -> MutMap<Symbol, (SolvedType, Region)> {
// div : Float, Float -> Float // div : Float, Float -> Float
add_type( add_type(
Symbol::NUM_DIV_FLOAT, Symbol::NUM_DIV_FLOAT,
SolvedType::Func(vec![float_type(), float_type()], Box::new(float_type())), SolvedType::Func(
vec![float_type(), float_type()],
Box::new(result_type(float_type(), div_by_zero.clone())),
),
); );
// mod : Float, Float -> Float // mod : Float, Float -> Result Int [ DivByZero ]*
add_type( add_type(
Symbol::NUM_MOD_FLOAT, Symbol::NUM_MOD_FLOAT,
SolvedType::Func(vec![float_type(), float_type()], Box::new(float_type())), SolvedType::Func(
vec![float_type(), float_type()],
Box::new(result_type(float_type(), div_by_zero)),
),
); );
// sqrt : Float -> Float // sqrt : Float -> Float
let sqrt_of_negative = SolvedType::TagUnion(
vec![(TagName::Global("SqrtOfNegative".into()), vec![])],
Box::new(SolvedType::Wildcard),
);
add_type( add_type(
Symbol::NUM_SQRT, Symbol::NUM_SQRT,
SolvedType::Func(vec![float_type()], Box::new(float_type())), SolvedType::Func(
vec![float_type()],
Box::new(result_type(float_type(), sqrt_of_negative)),
),
); );
// round : Float -> Int // round : Float -> Int

View file

@ -390,10 +390,52 @@ fn num_to_float(symbol: Symbol, var_store: &mut VarStore) -> Def {
/// Num.sqrt : Float -> Result Float [ SqrtOfNegative ]* /// Num.sqrt : Float -> Result Float [ SqrtOfNegative ]*
fn num_sqrt(symbol: Symbol, var_store: &mut VarStore) -> Def { fn num_sqrt(symbol: Symbol, var_store: &mut VarStore) -> Def {
let body = RunLowLevel { let bool_var = var_store.fresh();
op: LowLevel::NumSqrt, let num_var = var_store.fresh();
args: vec![(var_store.fresh(), Var(Symbol::ARG_1))], let unbound_zero_var = var_store.fresh();
ret_var: var_store.fresh(), let branch_var = var_store.fresh();
let body = If {
branch_var,
cond_var: bool_var,
branches: vec![(
// if-condition
no_region(
// Num.neq denominator 0
RunLowLevel {
op: LowLevel::NotEq,
args: vec![
(num_var, Var(Symbol::ARG_1)),
(num_var, Float(unbound_zero_var, 0.0)),
],
ret_var: bool_var,
},
),
// denominator was not zero
no_region(
// Ok (Float.#divUnchecked numerator denominator)
tag(
"Ok",
vec![
// Num.#divUnchecked numerator denominator
RunLowLevel {
op: LowLevel::NumSqrtUnchecked,
args: vec![(num_var, Var(Symbol::ARG_1))],
ret_var: num_var,
},
],
var_store,
),
),
)],
final_else: Box::new(
// denominator was zero
no_region(tag(
"Err",
vec![tag("DivByZero", Vec::new(), var_store)],
var_store,
)),
),
}; };
defn(symbol, vec![Symbol::ARG_1], var_store, body) defn(symbol, vec![Symbol::ARG_1], var_store, body)
@ -686,11 +728,11 @@ fn num_div_float(symbol: Symbol, var_store: &mut VarStore) -> Def {
), ),
// denominator was not zero // denominator was not zero
no_region( no_region(
// Ok (Float.#divUnsafe numerator denominator) // Ok (Float.#divUnchecked numerator denominator)
tag( tag(
"Ok", "Ok",
vec![ vec![
// Num.#divUnsafe numerator denominator // Num.#divUnchecked numerator denominator
RunLowLevel { RunLowLevel {
op: LowLevel::NumDivUnchecked, op: LowLevel::NumDivUnchecked,
args: vec![ args: vec![
@ -720,9 +762,13 @@ fn num_div_float(symbol: Symbol, var_store: &mut VarStore) -> Def {
/// Num.div : Int, Int -> Result Int [ DivByZero ]* /// Num.div : Int, Int -> Result Int [ DivByZero ]*
fn num_div_int(symbol: Symbol, var_store: &mut VarStore) -> Def { fn num_div_int(symbol: Symbol, var_store: &mut VarStore) -> Def {
let bool_var = var_store.fresh(); let bool_var = var_store.fresh();
let num_var = var_store.fresh();
let unbound_zero_var = var_store.fresh();
let branch_var = var_store.fresh();
let body = If { let body = If {
branch_var: var_store.fresh(), branch_var,
cond_var: var_store.fresh(), cond_var: bool_var,
branches: vec![( branches: vec![(
// if-condition // if-condition
no_region( no_region(
@ -730,26 +776,26 @@ fn num_div_int(symbol: Symbol, var_store: &mut VarStore) -> Def {
RunLowLevel { RunLowLevel {
op: LowLevel::NotEq, op: LowLevel::NotEq,
args: vec![ args: vec![
(bool_var, Var(Symbol::ARG_1)), (num_var, Var(Symbol::ARG_1)),
(bool_var, Int(var_store.fresh(), 0)), (num_var, Int(unbound_zero_var, 0)),
], ],
ret_var: var_store.fresh(), ret_var: bool_var,
}, },
), ),
// denominator was not zero // denominator was not zero
no_region( no_region(
// Ok (Int.#divUnsafe numerator denominator) // Ok (Int.#divUnchecked numerator denominator)
tag( tag(
"Ok", "Ok",
vec![ vec![
// Num.#divUnsafe numerator denominator // Num.#divUnchecked numerator denominator
RunLowLevel { RunLowLevel {
op: LowLevel::NumDivUnchecked, op: LowLevel::NumDivUnchecked,
args: vec![ args: vec![
(var_store.fresh(), Var(Symbol::ARG_1)), (num_var, Var(Symbol::ARG_1)),
(var_store.fresh(), Var(Symbol::ARG_2)), (num_var, Var(Symbol::ARG_2)),
], ],
ret_var: var_store.fresh(), ret_var: num_var,
}, },
], ],
var_store, var_store,

View file

@ -1366,7 +1366,7 @@ fn run_low_level<'a, 'ctx, 'env>(
"cast_collection", "cast_collection",
) )
} }
NumAbs | NumNeg | NumRound | NumSqrt | NumSin | NumCos | NumToFloat => { NumAbs | NumNeg | NumRound | NumSqrtUnchecked | NumSin | NumCos | NumToFloat => {
debug_assert_eq!(args.len(), 1); debug_assert_eq!(args.len(), 1);
let arg = build_expr(env, layout_ids, scope, parent, &args[0].0); let arg = build_expr(env, layout_ids, scope, parent, &args[0].0);
@ -1683,7 +1683,7 @@ fn build_float_unary_op<'a, 'ctx, 'env>(
match op { match op {
NumNeg => bd.build_float_neg(arg, "negate_float").into(), NumNeg => bd.build_float_neg(arg, "negate_float").into(),
NumAbs => call_intrinsic(LLVM_FABS_F64, env, &[(arg.into(), arg_layout)]), NumAbs => call_intrinsic(LLVM_FABS_F64, env, &[(arg.into(), arg_layout)]),
NumSqrt => call_intrinsic(LLVM_SQRT_F64, env, &[(arg.into(), arg_layout)]), NumSqrtUnchecked => call_intrinsic(LLVM_SQRT_F64, env, &[(arg.into(), arg_layout)]),
NumRound => call_intrinsic(LLVM_LROUND_I64_F64, env, &[(arg.into(), arg_layout)]), NumRound => call_intrinsic(LLVM_LROUND_I64_F64, env, &[(arg.into(), arg_layout)]),
NumSin => call_intrinsic(LLVM_SIN_F64, env, &[(arg.into(), arg_layout)]), NumSin => call_intrinsic(LLVM_SIN_F64, env, &[(arg.into(), arg_layout)]),
NumCos => call_intrinsic(LLVM_COS_F64, env, &[(arg.into(), arg_layout)]), NumCos => call_intrinsic(LLVM_COS_F64, env, &[(arg.into(), arg_layout)]),

View file

@ -29,7 +29,17 @@ mod gen_builtins {
#[test] #[test]
fn f64_sqrt() { fn f64_sqrt() {
assert_evals_to!("Num.sqrt 144", 12.0, f64); assert_evals_to!(
indoc!(
r#"
when Num.sqrt 144 is
Ok val -> val
Err _ -> -1
"#
),
12.0,
f64
);
} }
#[test] #[test]
@ -234,7 +244,7 @@ mod gen_builtins {
r#" r#"
when Num.rem 8 3 is when Num.rem 8 3 is
Ok val -> val Ok val -> val
_ -> -1 Err _ -> -1
"# "#
), ),
2, 2,
@ -249,7 +259,7 @@ mod gen_builtins {
r#" r#"
when Num.rem 8 0 is when Num.rem 8 0 is
Err DivByZero -> 4 Err DivByZero -> 4
_ -> -23 Ok _ -> -23
"# "#
), ),
4, 4,

View file

@ -21,7 +21,7 @@ pub enum LowLevel {
NumNeg, NumNeg,
NumSin, NumSin,
NumCos, NumCos,
NumSqrt, NumSqrtUnchecked,
NumRound, NumRound,
NumToFloat, NumToFloat,
Eq, Eq,

View file

@ -756,8 +756,8 @@ fn annotate_low_level_usage(
} }
ListSingle | NumAdd | NumSub | NumMul | NumGt | NumGte | NumLt | NumLte | NumAbs ListSingle | NumAdd | NumSub | NumMul | NumGt | NumGte | NumLt | NumLte | NumAbs
| NumNeg | NumDivUnchecked | NumRemUnchecked | NumSqrt | NumRound | NumSin | NumCos | NumNeg | NumDivUnchecked | NumRemUnchecked | NumSqrtUnchecked | NumRound | NumSin
| Eq | NotEq | And | Or | Not | NumToFloat => { | NumCos | Eq | NotEq | And | Or | Not | NumToFloat => {
for (_, arg) in args { for (_, arg) in args {
annotate_usage(&arg, usage); annotate_usage(&arg, usage);
} }