diff --git a/compiler/builtins/src/unique.rs b/compiler/builtins/src/unique.rs index 8f33fbdad3..af4aa859b1 100644 --- a/compiler/builtins/src/unique.rs +++ b/compiler/builtins/src/unique.rs @@ -44,7 +44,7 @@ const TVAR3: VarId = VarId::from_u32(3); const FUVAR: VarId = VarId::from_u32(1000); const UVAR1: VarId = VarId::from_u32(1001); const UVAR2: VarId = VarId::from_u32(1002); -const UVAR3: VarId = VarId::from_u32(1003); +// const UVAR3: VarId = VarId::from_u32(1003); const UVAR4: VarId = VarId::from_u32(1004); const UVAR5: VarId = VarId::from_u32(1005); const UVAR6: VarId = VarId::from_u32(1006); @@ -297,10 +297,8 @@ pub fn types() -> MutMap { // , Attr v (Num (Attr v num)) // -> Attr w (Num (Attr w num)) add_type(Symbol::NUM_ADD, { - unique_function( - vec![num_type(UVAR1, TVAR1), num_type(UVAR2, TVAR1)], - num_type(UVAR3, TVAR1), - ) + let_tvars! { u, v, w, num }; + unique_function(vec![num_type(u, num), num_type(v, num)], num_type(w, num)) }); // sub or (-) : Num a, Num a -> Num a @@ -310,61 +308,41 @@ pub fn types() -> MutMap { }); // mul or (*) : Num a, Num a -> Num a - add_type( - Symbol::NUM_MUL, - unique_function( - vec![num_type(UVAR1, TVAR1), num_type(UVAR2, TVAR1)], - num_type(UVAR3, TVAR1), - ), - ); + add_type(Symbol::NUM_MUL, { + let_tvars! { u, v, w, num }; + unique_function(vec![num_type(u, num), num_type(v, num)], num_type(w, num)) + }); // abs : Num a -> Num a - add_type( - Symbol::NUM_ABS, - unique_function(vec![num_type(UVAR1, TVAR1)], num_type(UVAR2, TVAR1)), - ); + add_type(Symbol::NUM_ABS, { + let_tvars! { u, v, num }; + unique_function(vec![num_type(u, num)], num_type(v, num)) + }); // neg : Num a -> Num a - add_type( - Symbol::NUM_NEG, - unique_function(vec![num_type(UVAR1, TVAR1)], num_type(UVAR2, TVAR1)), - ); + add_type(Symbol::NUM_NEG, { + let_tvars! { u, v, num }; + unique_function(vec![num_type(u, num)], num_type(v, num)) + }); + + let mut add_num_comparison = |symbol| { + add_type(symbol, { + let_tvars! { u, v, w, num }; + unique_function(vec![num_type(u, num), num_type(v, num)], bool_type(w)) + }); + }; // isLt or (<) : Num a, Num a -> Bool - add_type( - Symbol::NUM_LT, - unique_function( - vec![num_type(UVAR1, TVAR1), num_type(UVAR2, TVAR1)], - bool_type(UVAR3), - ), - ); + add_num_comparison(Symbol::NUM_LT); // isLte or (<=) : Num a, Num a -> Bool - add_type( - Symbol::NUM_LTE, - unique_function( - vec![num_type(UVAR1, TVAR1), num_type(UVAR2, TVAR1)], - bool_type(UVAR3), - ), - ); + add_num_comparison(Symbol::NUM_LTE); // isGt or (>) : Num a, Num a -> Bool - add_type( - Symbol::NUM_GT, - unique_function( - vec![num_type(UVAR1, TVAR1), num_type(UVAR2, TVAR1)], - bool_type(UVAR3), - ), - ); + add_num_comparison(Symbol::NUM_GT); // isGte or (>=) : Num a, Num a -> Bool - add_type( - Symbol::NUM_GTE, - unique_function( - vec![num_type(UVAR1, TVAR1), num_type(UVAR2, TVAR1)], - bool_type(UVAR3), - ), - ); + add_num_comparison(Symbol::NUM_GTE); // toFloat : Num a -> Float add_type( @@ -375,42 +353,42 @@ pub fn types() -> MutMap { // Int module // isLt or (<) : Num a, Num a -> Bool - add_type( - Symbol::INT_LT, - unique_function(vec![int_type(UVAR1), int_type(UVAR2)], bool_type(UVAR3)), - ); + add_type(Symbol::INT_LT, { + let_tvars! { u, v, w }; + unique_function(vec![int_type(u), int_type(v)], bool_type(w)) + }); // equals or (==) : Int, Int -> Bool - add_type( - Symbol::INT_EQ_I64, - unique_function(vec![int_type(UVAR1), int_type(UVAR2)], bool_type(UVAR3)), - ); + add_type(Symbol::INT_EQ_I64, { + let_tvars! { u, v, w }; + unique_function(vec![int_type(u), int_type(v)], bool_type(w)) + }); // not equals or (!=) : Int, Int -> Bool - add_type( - Symbol::INT_NEQ_I64, - unique_function(vec![int_type(UVAR1), int_type(UVAR2)], bool_type(UVAR3)), - ); + add_type(Symbol::INT_NEQ_I64, { + let_tvars! { u, v, w }; + unique_function(vec![int_type(u), int_type(v)], bool_type(w)) + }); // abs : Int -> Int - add_type( - Symbol::INT_ABS, - unique_function(vec![int_type(UVAR1)], int_type(UVAR2)), - ); + add_type(Symbol::INT_ABS, { + let_tvars! { u, v }; + unique_function(vec![int_type(u)], int_type(v)) + }); - // rem : Int, Int -> Result Int [ DivByZero ]* - add_type( - Symbol::INT_REM, + // rem : Attr * Int, Attr * Int -> Attr * (Result (Attr * Int) (Attr * [ DivByZero ]*)) + add_type(Symbol::INT_REM, { + let_tvars! { star1, star2, star3, star4, star5 }; unique_function( - vec![int_type(UVAR1), int_type(UVAR2)], - result_type(UVAR3, int_type(UVAR4), lift(UVAR5, div_by_zero())), - ), - ); + vec![int_type(star1), int_type(star2)], + result_type(star3, int_type(star4), lift(star5, div_by_zero())), + ) + }); - add_type( - Symbol::INT_REM_UNSAFE, - unique_function(vec![int_type(UVAR1), int_type(UVAR2)], int_type(UVAR3)), - ); + add_type(Symbol::INT_REM_UNSAFE, { + let_tvars! { star1, star2, star3, }; + unique_function(vec![int_type(star1), int_type(star2)], int_type(star3)) + }); // highest : Int add_type(Symbol::INT_HIGHEST, int_type(UVAR1)); @@ -419,92 +397,92 @@ pub fn types() -> MutMap { add_type(Symbol::INT_LOWEST, int_type(UVAR1)); // div or (//) : Int, Int -> Result Int [ DivByZero ]* - add_type( - Symbol::INT_DIV, + add_type(Symbol::INT_DIV, { + let_tvars! { star1, star2, star3, star4, star5 }; unique_function( - vec![int_type(UVAR1), int_type(UVAR2)], - result_type(UVAR3, int_type(UVAR4), lift(UVAR5, div_by_zero())), - ), - ); + vec![int_type(star1), int_type(star2)], + result_type(star3, int_type(star4), lift(star5, div_by_zero())), + ) + }); - add_type( - Symbol::INT_DIV_UNSAFE, - unique_function(vec![int_type(UVAR1), int_type(UVAR2)], int_type(UVAR3)), - ); + add_type(Symbol::INT_DIV_UNSAFE, { + let_tvars! { star1, star2, star3, }; + unique_function(vec![int_type(star1), int_type(star2)], int_type(star3)) + }); // mod : Int, Int -> Int - add_type( - Symbol::INT_MOD, - unique_function(vec![int_type(UVAR1), int_type(UVAR2)], int_type(UVAR3)), - ); + add_type(Symbol::INT_MOD, { + let_tvars! { star1, star2, star3, }; + unique_function(vec![int_type(star1), int_type(star2)], int_type(star3)) + }); // Float module // isGt or (>) : Num a, Num a -> Bool - add_type( - Symbol::FLOAT_GT, - unique_function(vec![float_type(UVAR1), float_type(UVAR2)], bool_type(UVAR3)), - ); + add_type(Symbol::FLOAT_GT, { + let_tvars! { star1, star2, star3} + unique_function(vec![float_type(star1), float_type(star2)], bool_type(star3)) + }); // eq or (==) : Num a, Num a -> Bool - add_type( - Symbol::FLOAT_EQ, - unique_function(vec![float_type(UVAR1), float_type(UVAR2)], bool_type(UVAR3)), - ); + add_type(Symbol::FLOAT_EQ, { + let_tvars! { star1, star2, star3} + unique_function(vec![float_type(star1), float_type(star2)], bool_type(star3)) + }); // div : Float, Float -> Float - add_type( - Symbol::FLOAT_DIV, + add_type(Symbol::FLOAT_DIV, { + let_tvars! { star1, star2, star3}; unique_function( - vec![float_type(UVAR1), float_type(UVAR2)], - float_type(UVAR3), - ), - ); + vec![float_type(star1), float_type(star2)], + float_type(star3), + ) + }); // mod : Float, Float -> Float - add_type( - Symbol::FLOAT_MOD, + add_type(Symbol::FLOAT_MOD, { + let_tvars! { star1, star2, star3}; unique_function( - vec![float_type(UVAR1), float_type(UVAR2)], - float_type(UVAR3), - ), - ); - - // sqrt : Float -> Float - add_type( - Symbol::FLOAT_SQRT, - unique_function(vec![float_type(UVAR1)], float_type(UVAR2)), - ); + vec![float_type(star1), float_type(star2)], + float_type(star3), + ) + }); // round : Float -> Int - add_type( - Symbol::FLOAT_ROUND, - unique_function(vec![float_type(UVAR1)], int_type(UVAR2)), - ); + add_type(Symbol::FLOAT_ROUND, { + let_tvars! { star1, star2 }; + unique_function(vec![float_type(star1)], int_type(star2)) + }); + + // sqrt : Float -> Float + add_type(Symbol::FLOAT_SQRT, { + let_tvars! { star1, star2 }; + unique_function(vec![float_type(star1)], float_type(star2)) + }); // abs : Float -> Float - add_type( - Symbol::FLOAT_ABS, - unique_function(vec![float_type(UVAR1)], float_type(UVAR2)), - ); + add_type(Symbol::FLOAT_ABS, { + let_tvars! { star1, star2 }; + unique_function(vec![float_type(star1)], float_type(star2)) + }); // sin : Float -> Float - add_type( - Symbol::FLOAT_SIN, - unique_function(vec![float_type(UVAR1)], float_type(UVAR2)), - ); + add_type(Symbol::FLOAT_SIN, { + let_tvars! { star1, star2 }; + unique_function(vec![float_type(star1)], float_type(star2)) + }); // cos : Float -> Float - add_type( - Symbol::FLOAT_COS, - unique_function(vec![float_type(UVAR1)], float_type(UVAR2)), - ); + add_type(Symbol::FLOAT_COS, { + let_tvars! { star1, star2 }; + unique_function(vec![float_type(star1)], float_type(star2)) + }); // tan : Float -> Float - add_type( - Symbol::FLOAT_TAN, - unique_function(vec![float_type(UVAR1)], float_type(UVAR2)), - ); + add_type(Symbol::FLOAT_TAN, { + let_tvars! { star1, star2 }; + unique_function(vec![float_type(star1)], float_type(star2)) + }); // highest : Float add_type(Symbol::FLOAT_HIGHEST, float_type(UVAR1)); @@ -514,41 +492,47 @@ pub fn types() -> MutMap { // Bool module - // isEq or (==) : a, a -> Attr u Bool - add_type( - Symbol::BOOL_EQ, - unique_function(vec![flex(TVAR1), flex(TVAR1)], bool_type(UVAR3)), - ); + // isEq or (==) : Attr * a, Attr * a -> Attr * Bool + add_type(Symbol::BOOL_EQ, { + let_tvars! { star1, star2, star3, a }; + unique_function( + vec![attr_type(star1, a), attr_type(star2, a)], + bool_type(star3), + ) + }); - // isNeq or (!=) : a, a -> Attr u Bool - add_type( - Symbol::BOOL_NEQ, - unique_function(vec![flex(TVAR1), flex(TVAR1)], bool_type(UVAR3)), - ); + // isNeq or (!=) : Attr * a, Attr * a -> Attr * Bool + add_type(Symbol::BOOL_NEQ, { + let_tvars! { star1, star2, star3, a }; + unique_function( + vec![attr_type(star1, a), attr_type(star2, a)], + bool_type(star3), + ) + }); // and or (&&) : Attr u1 Bool, Attr u2 Bool -> Attr u3 Bool - add_type( - Symbol::BOOL_AND, - unique_function(vec![bool_type(UVAR1), bool_type(UVAR2)], bool_type(UVAR3)), - ); + add_type(Symbol::BOOL_AND, { + let_tvars! { star1, star2, star3}; + unique_function(vec![bool_type(star1), bool_type(star2)], bool_type(star3)) + }); // or or (||) : Attr u1 Bool, Attr u2 Bool -> Attr u3 Bool - add_type( - Symbol::BOOL_OR, - unique_function(vec![bool_type(UVAR1), bool_type(UVAR2)], bool_type(UVAR3)), - ); + add_type(Symbol::BOOL_OR, { + let_tvars! { star1, star2, star3}; + unique_function(vec![bool_type(star1), bool_type(star2)], bool_type(star3)) + }); // xor : Attr u1 Bool, Attr u2 Bool -> Attr u3 Bool - add_type( - Symbol::BOOL_XOR, - unique_function(vec![bool_type(UVAR1), bool_type(UVAR2)], bool_type(UVAR3)), - ); + add_type(Symbol::BOOL_XOR, { + let_tvars! { star1, star2, star3}; + unique_function(vec![bool_type(star1), bool_type(star2)], bool_type(star3)) + }); // not : Attr u1 Bool -> Attr u2 Bool - add_type( - Symbol::BOOL_NOT, - unique_function(vec![bool_type(UVAR1)], bool_type(UVAR2)), - ); + add_type(Symbol::BOOL_NOT, { + let_tvars! { star1, star2 }; + unique_function(vec![bool_type(star1)], bool_type(star2)) + }); // List module diff --git a/compiler/solve/tests/test_uniq_solve.rs b/compiler/solve/tests/test_uniq_solve.rs index e60280e7c3..9a0c65e174 100644 --- a/compiler/solve/tests/test_uniq_solve.rs +++ b/compiler/solve/tests/test_uniq_solve.rs @@ -2402,14 +2402,18 @@ mod test_uniq_solve { } #[test] - fn equals() { + fn bool_eq() { infer_eq( - indoc!( - r#" - \a, b -> a == b - "# - ), - "Attr * (a, a -> Attr * Bool)", + "\\a, b -> a == b", + "Attr * (Attr * a, Attr * a -> Attr * Bool)", + ); + } + + #[test] + fn bool_neq() { + infer_eq( + "\\a, b -> a != b", + "Attr * (Attr * a, Attr * a -> Attr * Bool)", ); }