diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9d2b45afc9..02ec518c95 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,14 +8,14 @@ env: jobs: prep-dependency-container: name: fmt, clippy, test --release - runs-on: self-hosted + runs-on: [self-hosted] timeout-minutes: 60 env: FORCE_COLOR: 1 steps: - uses: actions/checkout@v2 with: - clean: "false" + clean: "true" - name: Earthly version run: earthly --version diff --git a/compiler/builtins/bitcode/src/list.zig b/compiler/builtins/bitcode/src/list.zig index b43e329405..1791a8f61b 100644 --- a/compiler/builtins/bitcode/src/list.zig +++ b/compiler/builtins/bitcode/src/list.zig @@ -11,6 +11,7 @@ const CompareFn = fn (?[*]u8, ?[*]u8, ?[*]u8) callconv(.C) u8; const Opaque = ?[*]u8; const Inc = fn (?[*]u8) callconv(.C) void; +const IncN = fn (?[*]u8, usize) callconv(.C) void; const Dec = fn (?[*]u8) callconv(.C) void; pub const RocList = extern struct { @@ -615,7 +616,7 @@ pub fn listContains(list: RocList, key: Opaque, key_width: usize, is_eq: EqFn) c return false; } -pub fn listRepeat(count: usize, alignment: usize, element: Opaque, element_width: usize, inc_n_element: Inc) callconv(.C) RocList { +pub fn listRepeat(count: usize, alignment: usize, element: Opaque, element_width: usize, inc_n_element: IncN) callconv(.C) RocList { if (count == 0) { return RocList.empty(); } @@ -624,18 +625,15 @@ pub fn listRepeat(count: usize, alignment: usize, element: Opaque, element_width var output = RocList.allocate(allocator, alignment, count, element_width); if (output.bytes) |target_ptr| { + // increment the element's RC N times + inc_n_element(element, count); + var i: usize = 0; const source = element orelse unreachable; while (i < count) : (i += 1) { @memcpy(target_ptr + i * element_width, source, element_width); } - // TODO do all increments at once! - i = 0; - while (i < count) : (i += 1) { - inc_n_element(element); - } - return output; } else { unreachable; diff --git a/compiler/builtins/bitcode/src/utils.zig b/compiler/builtins/bitcode/src/utils.zig index f36eee65d5..ab3e6951f7 100644 --- a/compiler/builtins/bitcode/src/utils.zig +++ b/compiler/builtins/bitcode/src/utils.zig @@ -1,6 +1,10 @@ const std = @import("std"); const Allocator = std.mem.Allocator; +pub const Inc = fn (?[*]u8) callconv(.C) void; +pub const IncN = fn (?[*]u8, u64) callconv(.C) void; +pub const Dec = fn (?[*]u8) callconv(.C) void; + const REFCOUNT_MAX_ISIZE: comptime isize = 0; const REFCOUNT_ONE_ISIZE: comptime isize = std.math.minInt(isize); pub const REFCOUNT_ONE: usize = @bitCast(usize, REFCOUNT_ONE_ISIZE); diff --git a/compiler/builtins/src/std.rs b/compiler/builtins/src/std.rs index e7f40f2b7d..c0d8f9b39e 100644 --- a/compiler/builtins/src/std.rs +++ b/compiler/builtins/src/std.rs @@ -68,28 +68,53 @@ const TOP_LEVEL_CLOSURE_VAR: VarId = VarId::from_u32(5); pub fn types() -> MutMap { let mut types = HashMap::with_capacity_and_hasher(NUM_BUILTIN_IMPORTS, default_hasher()); - let mut add_type = |symbol, typ| { - debug_assert!( - !types.contains_key(&symbol), - "Duplicate type definition for {:?}", - symbol - ); + macro_rules! add_type { + ($symbol:expr, $typ:expr $(,)?) => {{ + debug_assert!( + !types.contains_key(&$symbol), + "Duplicate type definition for {:?}", + $symbol + ); - // TODO instead of using Region::zero for all of these, - // instead use the Region where they were defined in their - // source .roc files! This can give nicer error messages. - types.insert(symbol, (typ, Region::zero())); - }; + // TODO instead of using Region::zero for all of these, + // instead use the Region where they were defined in their + // source .roc files! This can give nicer error messages. + types.insert($symbol, ($typ, Region::zero())); + }}; + } + + macro_rules! add_top_level_function_type { + ($symbol:expr, $arguments:expr, $result:expr $(,)?) => {{ + debug_assert!( + !types.contains_key(&$symbol), + "Duplicate type definition for {:?}", + $symbol + ); + + let ext = Box::new(SolvedType::Flex(TOP_LEVEL_CLOSURE_VAR)); + // in the future, we will enable the line below + // let closure_var = Box::new(SolvedType::TagUnion( + // vec![(TagName::Closure($symbol), vec![])], + // ext, + // )); + let closure_var = ext; + + let typ = SolvedType::Func($arguments, closure_var, $result); + + // TODO instead of using Region::zero for all of these, + // instead use the Region where they were defined in their + // source .roc files! This can give nicer error messages. + types.insert($symbol, (typ, Region::zero())); + }}; + } // Num module // add or (+) : Num a, Num a -> Num a - add_type( + add_top_level_function_type!( Symbol::NUM_ADD, - top_level_function( - vec![num_type(flex(TVAR1)), num_type(flex(TVAR1))], - Box::new(num_type(flex(TVAR1))), - ), + vec![num_type(flex(TVAR1)), num_type(flex(TVAR1))], + Box::new(num_type(flex(TVAR1))), ); fn overflow() -> SolvedType { @@ -100,190 +125,171 @@ pub fn types() -> MutMap { } // addChecked : Num a, Num a -> Result (Num a) [ Overflow ]* - add_type( + add_top_level_function_type!( Symbol::NUM_ADD_CHECKED, - top_level_function( - vec![num_type(flex(TVAR1)), num_type(flex(TVAR1))], - Box::new(result_type(num_type(flex(TVAR1)), overflow())), - ), + vec![num_type(flex(TVAR1)), num_type(flex(TVAR1))], + Box::new(result_type(num_type(flex(TVAR1)), overflow())), ); // addWrap : Int range, Int range -> Int range - add_type( + add_top_level_function_type!( Symbol::NUM_ADD_WRAP, - top_level_function( - vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], - Box::new(int_type(flex(TVAR1))), - ), + vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], + Box::new(int_type(flex(TVAR1))), ); // sub or (-) : Num a, Num a -> Num a - add_type( + add_top_level_function_type!( Symbol::NUM_SUB, - top_level_function( - vec![num_type(flex(TVAR1)), num_type(flex(TVAR1))], - Box::new(num_type(flex(TVAR1))), - ), + vec![num_type(flex(TVAR1)), num_type(flex(TVAR1))], + Box::new(num_type(flex(TVAR1))), ); // subWrap : Int range, Int range -> Int range - add_type( + add_top_level_function_type!( Symbol::NUM_SUB_WRAP, - top_level_function( - vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], - Box::new(int_type(flex(TVAR1))), - ), + vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], + Box::new(int_type(flex(TVAR1))), ); // subChecked : Num a, Num a -> Result (Num a) [ Overflow ]* - add_type( + add_top_level_function_type!( Symbol::NUM_SUB_CHECKED, - top_level_function( - vec![num_type(flex(TVAR1)), num_type(flex(TVAR1))], - Box::new(result_type(num_type(flex(TVAR1)), overflow())), - ), + vec![num_type(flex(TVAR1)), num_type(flex(TVAR1))], + Box::new(result_type(num_type(flex(TVAR1)), overflow())), ); // mul or (*) : Num a, Num a -> Num a - add_type( + add_top_level_function_type!( Symbol::NUM_MUL, - top_level_function( - vec![num_type(flex(TVAR1)), num_type(flex(TVAR1))], - Box::new(num_type(flex(TVAR1))), - ), + vec![num_type(flex(TVAR1)), num_type(flex(TVAR1))], + Box::new(num_type(flex(TVAR1))), ); // mulWrap : Int range, Int range -> Int range - add_type( + add_top_level_function_type!( Symbol::NUM_MUL_WRAP, - top_level_function( - vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], - Box::new(int_type(flex(TVAR1))), - ), + vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], + Box::new(int_type(flex(TVAR1))), ); // mulChecked : Num a, Num a -> Result (Num a) [ Overflow ]* - add_type( + add_top_level_function_type!( Symbol::NUM_MUL_CHECKED, - top_level_function( - vec![num_type(flex(TVAR1)), num_type(flex(TVAR1))], - Box::new(result_type(num_type(flex(TVAR1)), overflow())), - ), + vec![num_type(flex(TVAR1)), num_type(flex(TVAR1))], + Box::new(result_type(num_type(flex(TVAR1)), overflow())), ); // abs : Num a -> Num a - add_type( + add_top_level_function_type!( Symbol::NUM_ABS, - top_level_function(vec![num_type(flex(TVAR1))], Box::new(num_type(flex(TVAR1)))), + vec![num_type(flex(TVAR1))], + Box::new(num_type(flex(TVAR1))) ); // neg : Num a -> Num a - add_type( + add_top_level_function_type!( Symbol::NUM_NEG, - top_level_function(vec![num_type(flex(TVAR1))], Box::new(num_type(flex(TVAR1)))), + vec![num_type(flex(TVAR1))], + Box::new(num_type(flex(TVAR1))) ); // isEq or (==) : a, a -> Bool - add_type( + add_top_level_function_type!( Symbol::BOOL_EQ, - top_level_function(vec![flex(TVAR1), flex(TVAR1)], Box::new(bool_type())), + vec![flex(TVAR1), flex(TVAR1)], + Box::new(bool_type()) ); // isNeq or (!=) : a, a -> Bool - add_type( + add_top_level_function_type!( Symbol::BOOL_NEQ, - top_level_function(vec![flex(TVAR1), flex(TVAR1)], Box::new(bool_type())), + vec![flex(TVAR1), flex(TVAR1)], + Box::new(bool_type()) ); // isLt or (<) : Num a, Num a -> Bool - add_type( + add_top_level_function_type!( Symbol::NUM_LT, - top_level_function( - vec![num_type(flex(TVAR1)), num_type(flex(TVAR1))], - Box::new(bool_type()), - ), + vec![num_type(flex(TVAR1)), num_type(flex(TVAR1))], + Box::new(bool_type()), ); // isLte or (<=) : Num a, Num a -> Bool - add_type( + add_top_level_function_type!( Symbol::NUM_LTE, - top_level_function( - vec![num_type(flex(TVAR1)), num_type(flex(TVAR1))], - Box::new(bool_type()), - ), + vec![num_type(flex(TVAR1)), num_type(flex(TVAR1))], + Box::new(bool_type()), ); // isGt or (>) : Num a, Num a -> Bool - add_type( + add_top_level_function_type!( Symbol::NUM_GT, - top_level_function( - vec![num_type(flex(TVAR1)), num_type(flex(TVAR1))], - Box::new(bool_type()), - ), + vec![num_type(flex(TVAR1)), num_type(flex(TVAR1))], + Box::new(bool_type()), ); // isGte or (>=) : Num a, Num a -> Bool - add_type( + add_top_level_function_type!( Symbol::NUM_GTE, - top_level_function( - vec![num_type(flex(TVAR1)), num_type(flex(TVAR1))], - Box::new(bool_type()), - ), + vec![num_type(flex(TVAR1)), num_type(flex(TVAR1))], + Box::new(bool_type()), ); // compare : Num a, Num a -> [ LT, EQ, GT ] - add_type( + add_top_level_function_type!( Symbol::NUM_COMPARE, - top_level_function( - vec![num_type(flex(TVAR1)), num_type(flex(TVAR1))], - Box::new(ordering_type()), - ), + vec![num_type(flex(TVAR1)), num_type(flex(TVAR1))], + Box::new(ordering_type()), ); // toFloat : Num * -> Float * - add_type( + add_top_level_function_type!( Symbol::NUM_TO_FLOAT, - top_level_function( - vec![num_type(flex(TVAR1))], - Box::new(float_type(flex(TVAR2))), - ), + vec![num_type(flex(TVAR1))], + Box::new(float_type(flex(TVAR2))), ); // isNegative : Num a -> Bool - add_type( + add_top_level_function_type!( Symbol::NUM_IS_NEGATIVE, - top_level_function(vec![num_type(flex(TVAR1))], Box::new(bool_type())), + vec![num_type(flex(TVAR1))], + Box::new(bool_type()) ); // isPositive : Num a -> Bool - add_type( + add_top_level_function_type!( Symbol::NUM_IS_POSITIVE, - top_level_function(vec![num_type(flex(TVAR1))], Box::new(bool_type())), + vec![num_type(flex(TVAR1))], + Box::new(bool_type()) ); // isZero : Num a -> Bool - add_type( + add_top_level_function_type!( Symbol::NUM_IS_ZERO, - top_level_function(vec![num_type(flex(TVAR1))], Box::new(bool_type())), + vec![num_type(flex(TVAR1))], + Box::new(bool_type()) ); // isEven : Num a -> Bool - add_type( + add_top_level_function_type!( Symbol::NUM_IS_EVEN, - top_level_function(vec![num_type(flex(TVAR1))], Box::new(bool_type())), + vec![num_type(flex(TVAR1))], + Box::new(bool_type()) ); // isOdd : Num a -> Bool - add_type( + add_top_level_function_type!( Symbol::NUM_IS_ODD, - top_level_function(vec![num_type(flex(TVAR1))], Box::new(bool_type())), + vec![num_type(flex(TVAR1))], + Box::new(bool_type()) ); // maxInt : Int range - add_type(Symbol::NUM_MAX_INT, int_type(flex(TVAR1))); + add_type!(Symbol::NUM_MAX_INT, int_type(flex(TVAR1))); // minInt : Int range - add_type(Symbol::NUM_MIN_INT, int_type(flex(TVAR1))); + add_type!(Symbol::NUM_MIN_INT, int_type(flex(TVAR1))); // div : Int, Int -> Result Int [ DivByZero ]* let div_by_zero = SolvedType::TagUnion( @@ -291,122 +297,99 @@ pub fn types() -> MutMap { Box::new(SolvedType::Wildcard), ); - add_type( + add_top_level_function_type!( Symbol::NUM_DIV_INT, - top_level_function( - vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], - Box::new(result_type(int_type(flex(TVAR1)), div_by_zero.clone())), - ), + vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], + Box::new(result_type(int_type(flex(TVAR1)), div_by_zero.clone())), ); // bitwiseAnd : Int a, Int a -> Int a - add_type( + add_top_level_function_type!( Symbol::NUM_BITWISE_AND, - top_level_function( - vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], - Box::new(int_type(flex(TVAR1))), - ), + vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], + Box::new(int_type(flex(TVAR1))), ); // bitwiseXor : Int a, Int a -> Int a - add_type( + add_top_level_function_type!( Symbol::NUM_BITWISE_XOR, - top_level_function( - vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], - Box::new(int_type(flex(TVAR1))), - ), + vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], + Box::new(int_type(flex(TVAR1))), ); // bitwiseOr : Int a, Int a -> Int a - add_type( + add_top_level_function_type!( Symbol::NUM_BITWISE_OR, - top_level_function( - vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], - Box::new(int_type(flex(TVAR1))), - ), + vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], + Box::new(int_type(flex(TVAR1))), ); // shiftLeftBy : Int a, Int a -> Int a - add_type( + add_top_level_function_type!( Symbol::NUM_SHIFT_LEFT, - top_level_function( - vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], - Box::new(int_type(flex(TVAR1))), - ), + vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], + Box::new(int_type(flex(TVAR1))), ); // shiftRightBy : Int a, Int a -> Int a - add_type( + add_top_level_function_type!( Symbol::NUM_SHIFT_RIGHT, - top_level_function( - vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], - Box::new(int_type(flex(TVAR1))), - ), + vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], + Box::new(int_type(flex(TVAR1))), ); // shiftRightZfBy : Int a, Int a -> Int a - add_type( + add_top_level_function_type!( Symbol::NUM_SHIFT_RIGHT_ZERO_FILL, - top_level_function( - vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], - Box::new(int_type(flex(TVAR1))), - ), + vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], + Box::new(int_type(flex(TVAR1))), ); // intCast : Int a -> Int b - add_type( + add_top_level_function_type!( Symbol::NUM_INT_CAST, - top_level_function(vec![int_type(flex(TVAR1))], Box::new(int_type(flex(TVAR2)))), + vec![int_type(flex(TVAR1))], + Box::new(int_type(flex(TVAR2))) ); // rem : Int a, Int a -> Result (Int a) [ DivByZero ]* - add_type( + add_top_level_function_type!( Symbol::NUM_REM, - top_level_function( - vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], - Box::new(result_type(int_type(flex(TVAR1)), div_by_zero.clone())), - ), + vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], + Box::new(result_type(int_type(flex(TVAR1)), div_by_zero.clone())), ); // mod : Int a, Int a -> Result (Int a) [ DivByZero ]* - add_type( + add_top_level_function_type!( Symbol::NUM_MOD_INT, - top_level_function( - vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], - Box::new(result_type(int_type(flex(TVAR1)), div_by_zero.clone())), - ), + vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], + Box::new(result_type(int_type(flex(TVAR1)), div_by_zero.clone())), ); // isMultipleOf : Int a, Int a -> Bool - add_type( + add_top_level_function_type!( Symbol::NUM_IS_MULTIPLE_OF, - top_level_function( - vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], - Box::new(bool_type()), - ), + vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], + Box::new(bool_type()), ); // maxI128 : I128 - add_type(Symbol::NUM_MAX_I128, i128_type()); + add_type!(Symbol::NUM_MAX_I128, i128_type()); // Float module // div : Float a, Float a -> Float a - add_type( + add_top_level_function_type!( Symbol::NUM_DIV_FLOAT, - top_level_function( - vec![float_type(flex(TVAR1)), float_type(flex(TVAR1))], - Box::new(result_type(float_type(flex(TVAR1)), div_by_zero.clone())), - ), + vec![float_type(flex(TVAR1)), float_type(flex(TVAR1))], + Box::new(result_type(float_type(flex(TVAR1)), div_by_zero.clone())), ); // mod : Float a, Float a -> Result (Float a) [ DivByZero ]* - add_type( + add_top_level_function_type!( Symbol::NUM_MOD_FLOAT, - top_level_function( - vec![float_type(flex(TVAR1)), float_type(flex(TVAR1))], - Box::new(result_type(float_type(flex(TVAR1)), div_by_zero)), - ), + vec![float_type(flex(TVAR1)), float_type(flex(TVAR1))], + Box::new(result_type(float_type(flex(TVAR1)), div_by_zero)), ); // sqrt : Float a -> Float a @@ -415,12 +398,10 @@ pub fn types() -> MutMap { Box::new(SolvedType::Wildcard), ); - add_type( + add_top_level_function_type!( Symbol::NUM_SQRT, - top_level_function( - vec![float_type(flex(TVAR1))], - Box::new(result_type(float_type(flex(TVAR1)), sqrt_of_negative)), - ), + vec![float_type(flex(TVAR1))], + Box::new(result_type(float_type(flex(TVAR1)), sqrt_of_negative)), ); // log : Float a -> Float a @@ -429,205 +410,184 @@ pub fn types() -> MutMap { Box::new(SolvedType::Wildcard), ); - add_type( + add_top_level_function_type!( Symbol::NUM_LOG, - top_level_function( - vec![float_type(flex(TVAR1))], - Box::new(result_type(float_type(flex(TVAR1)), log_needs_positive)), - ), + vec![float_type(flex(TVAR1))], + Box::new(result_type(float_type(flex(TVAR1)), log_needs_positive)), ); // round : Float a -> Int b - add_type( + add_top_level_function_type!( Symbol::NUM_ROUND, - top_level_function( - vec![float_type(flex(TVAR1))], - Box::new(int_type(flex(TVAR2))), - ), + vec![float_type(flex(TVAR1))], + Box::new(int_type(flex(TVAR2))), ); // sin : Float a -> Float a - add_type( + add_top_level_function_type!( Symbol::NUM_SIN, - top_level_function( - vec![float_type(flex(TVAR1))], - Box::new(float_type(flex(TVAR1))), - ), + vec![float_type(flex(TVAR1))], + Box::new(float_type(flex(TVAR1))), ); // cos : Float a -> Float a - add_type( + add_top_level_function_type!( Symbol::NUM_COS, - top_level_function( - vec![float_type(flex(TVAR1))], - Box::new(float_type(flex(TVAR1))), - ), + vec![float_type(flex(TVAR1))], + Box::new(float_type(flex(TVAR1))), ); // tan : Float a -> Float a - add_type( + add_top_level_function_type!( Symbol::NUM_TAN, - top_level_function( - vec![float_type(flex(TVAR1))], - Box::new(float_type(flex(TVAR1))), - ), + vec![float_type(flex(TVAR1))], + Box::new(float_type(flex(TVAR1))), ); // maxFloat : Float a - add_type(Symbol::NUM_MAX_FLOAT, float_type(flex(TVAR1))); + add_type!(Symbol::NUM_MAX_FLOAT, float_type(flex(TVAR1))); // minFloat : Float a - add_type(Symbol::NUM_MIN_FLOAT, float_type(flex(TVAR1))); + add_type!(Symbol::NUM_MIN_FLOAT, float_type(flex(TVAR1))); // pow : Float a, Float a -> Float a - add_type( + add_top_level_function_type!( Symbol::NUM_POW, - top_level_function( - vec![float_type(flex(TVAR1)), float_type(flex(TVAR1))], - Box::new(float_type(flex(TVAR1))), - ), + vec![float_type(flex(TVAR1)), float_type(flex(TVAR1))], + Box::new(float_type(flex(TVAR1))), ); // ceiling : Float a -> Int b - add_type( + add_top_level_function_type!( Symbol::NUM_CEILING, - top_level_function( - vec![float_type(flex(TVAR1))], - Box::new(int_type(flex(TVAR2))), - ), + vec![float_type(flex(TVAR1))], + Box::new(int_type(flex(TVAR2))), ); // powInt : Int a, Int a -> Int a - add_type( + add_top_level_function_type!( Symbol::NUM_POW_INT, - top_level_function( - vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], - Box::new(int_type(flex(TVAR1))), - ), + vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], + Box::new(int_type(flex(TVAR1))), ); // floor : Float a -> Int b - add_type( + add_top_level_function_type!( Symbol::NUM_FLOOR, - top_level_function( - vec![float_type(flex(TVAR1))], - Box::new(int_type(flex(TVAR2))), - ), + vec![float_type(flex(TVAR1))], + Box::new(int_type(flex(TVAR2))), ); // atan : Float a -> Float a - add_type( + add_top_level_function_type!( Symbol::NUM_ATAN, - top_level_function( - vec![float_type(flex(TVAR1))], - Box::new(float_type(flex(TVAR1))), - ), + vec![float_type(flex(TVAR1))], + Box::new(float_type(flex(TVAR1))), ); // acos : Float a -> Float a - add_type( + add_top_level_function_type!( Symbol::NUM_ACOS, - top_level_function( - vec![float_type(flex(TVAR1))], - Box::new(float_type(flex(TVAR1))), - ), + vec![float_type(flex(TVAR1))], + Box::new(float_type(flex(TVAR1))), ); // asin : Float a -> Float a - add_type( + add_top_level_function_type!( Symbol::NUM_ASIN, - top_level_function( - vec![float_type(flex(TVAR1))], - Box::new(float_type(flex(TVAR1))), - ), + vec![float_type(flex(TVAR1))], + Box::new(float_type(flex(TVAR1))), ); // Bool module // and : Bool, Bool -> Bool - add_type( + add_top_level_function_type!( Symbol::BOOL_AND, - top_level_function(vec![bool_type(), bool_type()], Box::new(bool_type())), + vec![bool_type(), bool_type()], + Box::new(bool_type()) ); // or : Bool, Bool -> Bool - add_type( + add_top_level_function_type!( Symbol::BOOL_OR, - top_level_function(vec![bool_type(), bool_type()], Box::new(bool_type())), + vec![bool_type(), bool_type()], + Box::new(bool_type()) ); // xor : Bool, Bool -> Bool - add_type( + add_top_level_function_type!( Symbol::BOOL_XOR, - top_level_function(vec![bool_type(), bool_type()], Box::new(bool_type())), + vec![bool_type(), bool_type()], + Box::new(bool_type()) ); // not : Bool -> Bool - add_type( - Symbol::BOOL_NOT, - top_level_function(vec![bool_type()], Box::new(bool_type())), - ); + add_top_level_function_type!(Symbol::BOOL_NOT, vec![bool_type()], Box::new(bool_type())); // Str module // Str.split : Str, Str -> List Str - add_type( + add_top_level_function_type!( Symbol::STR_SPLIT, - top_level_function( - vec![str_type(), str_type()], - Box::new(list_type(str_type())), - ), + vec![str_type(), str_type()], + Box::new(list_type(str_type())), ); // Str.concat : Str, Str -> Str - add_type( + add_top_level_function_type!( Symbol::STR_CONCAT, - top_level_function(vec![str_type(), str_type()], Box::new(str_type())), + vec![str_type(), str_type()], + Box::new(str_type()), ); // Str.joinWith : List Str, Str -> Str - add_type( + add_top_level_function_type!( Symbol::STR_JOIN_WITH, - top_level_function( - vec![list_type(str_type()), str_type()], - Box::new(str_type()), - ), + vec![list_type(str_type()), str_type()], + Box::new(str_type()), ); // isEmpty : Str -> Bool - add_type( + add_top_level_function_type!( Symbol::STR_IS_EMPTY, - top_level_function(vec![str_type()], Box::new(bool_type())), + vec![str_type()], + Box::new(bool_type()) ); // startsWith : Str, Str -> Bool - add_type( + add_top_level_function_type!( Symbol::STR_STARTS_WITH, - top_level_function(vec![str_type(), str_type()], Box::new(bool_type())), + vec![str_type(), str_type()], + Box::new(bool_type()) ); // startsWithCodePoint : Str, U32 -> Bool - add_type( + add_top_level_function_type!( Symbol::STR_STARTS_WITH_CODE_POINT, - top_level_function(vec![str_type(), u32_type()], Box::new(bool_type())), + vec![str_type(), u32_type()], + Box::new(bool_type()) ); // endsWith : Str, Str -> Bool - add_type( + add_top_level_function_type!( Symbol::STR_ENDS_WITH, - top_level_function(vec![str_type(), str_type()], Box::new(bool_type())), + vec![str_type(), str_type()], + Box::new(bool_type()) ); // countGraphemes : Str -> Nat - add_type( + add_top_level_function_type!( Symbol::STR_COUNT_GRAPHEMES, - top_level_function(vec![str_type()], Box::new(nat_type())), + vec![str_type()], + Box::new(nat_type()) ); // fromInt : Int a -> Str - add_type( + add_top_level_function_type!( Symbol::STR_FROM_INT, - top_level_function(vec![int_type(flex(TVAR1))], Box::new(str_type())), + vec![int_type(flex(TVAR1))], + Box::new(str_type()) ); // fromUtf8 : List U8 -> Result Str [ BadUtf8 Utf8Problem ]* @@ -640,24 +600,24 @@ pub fn types() -> MutMap { Box::new(SolvedType::Wildcard), ); - add_type( + add_top_level_function_type!( Symbol::STR_FROM_UTF8, - top_level_function( - vec![list_type(u8_type())], - Box::new(result_type(str_type(), bad_utf8)), - ), + vec![list_type(u8_type())], + Box::new(result_type(str_type(), bad_utf8)), ); // toBytes : Str -> List U8 - add_type( + add_top_level_function_type!( Symbol::STR_TO_BYTES, - top_level_function(vec![str_type()], Box::new(list_type(u8_type()))), + vec![str_type()], + Box::new(list_type(u8_type())) ); // fromFloat : Float a -> Str - add_type( + add_top_level_function_type!( Symbol::STR_FROM_FLOAT, - top_level_function(vec![float_type(flex(TVAR1))], Box::new(str_type())), + vec![float_type(flex(TVAR1))], + Box::new(str_type()) ); // List module @@ -668,12 +628,10 @@ pub fn types() -> MutMap { Box::new(SolvedType::Wildcard), ); - add_type( + add_top_level_function_type!( Symbol::LIST_GET, - top_level_function( - vec![list_type(flex(TVAR1)), nat_type()], - Box::new(result_type(flex(TVAR1), index_out_of_bounds)), - ), + vec![list_type(flex(TVAR1)), nat_type()], + Box::new(result_type(flex(TVAR1), index_out_of_bounds)), ); // first : List elem -> Result elem [ ListWasEmpty ]* @@ -682,92 +640,74 @@ pub fn types() -> MutMap { Box::new(SolvedType::Wildcard), ); - add_type( + add_top_level_function_type!( Symbol::LIST_FIRST, - top_level_function( - vec![list_type(flex(TVAR1))], - Box::new(result_type(flex(TVAR1), list_was_empty.clone())), - ), + vec![list_type(flex(TVAR1))], + Box::new(result_type(flex(TVAR1), list_was_empty.clone())), ); // last : List elem -> Result elem [ ListWasEmpty ]* - add_type( + add_top_level_function_type!( Symbol::LIST_LAST, - top_level_function( - vec![list_type(flex(TVAR1))], - Box::new(result_type(flex(TVAR1), list_was_empty)), - ), + vec![list_type(flex(TVAR1))], + Box::new(result_type(flex(TVAR1), list_was_empty)), ); // set : List elem, Nat, elem -> List elem - add_type( + add_top_level_function_type!( Symbol::LIST_SET, - top_level_function( - vec![list_type(flex(TVAR1)), nat_type(), flex(TVAR1)], - Box::new(list_type(flex(TVAR1))), - ), + vec![list_type(flex(TVAR1)), nat_type(), flex(TVAR1)], + Box::new(list_type(flex(TVAR1))), ); // concat : List elem, List elem -> List elem - add_type( + add_top_level_function_type!( Symbol::LIST_CONCAT, - top_level_function( - vec![list_type(flex(TVAR1)), list_type(flex(TVAR1))], - Box::new(list_type(flex(TVAR1))), - ), + vec![list_type(flex(TVAR1)), list_type(flex(TVAR1))], + Box::new(list_type(flex(TVAR1))), ); // contains : List elem, elem -> Bool - add_type( + add_top_level_function_type!( Symbol::LIST_CONTAINS, - top_level_function( - vec![list_type(flex(TVAR1)), flex(TVAR1)], - Box::new(bool_type()), - ), + vec![list_type(flex(TVAR1)), flex(TVAR1)], + Box::new(bool_type()), ); // sum : List (Num a) -> Num a - add_type( + add_top_level_function_type!( Symbol::LIST_SUM, - top_level_function( - vec![list_type(num_type(flex(TVAR1)))], - Box::new(num_type(flex(TVAR1))), - ), + vec![list_type(num_type(flex(TVAR1)))], + Box::new(num_type(flex(TVAR1))), ); // product : List (Num a) -> Num a - add_type( + add_top_level_function_type!( Symbol::LIST_PRODUCT, - top_level_function( - vec![list_type(num_type(flex(TVAR1)))], - Box::new(num_type(flex(TVAR1))), - ), + vec![list_type(num_type(flex(TVAR1)))], + Box::new(num_type(flex(TVAR1))), ); // walk : List elem, (elem -> accum -> accum), accum -> accum - add_type( + add_top_level_function_type!( Symbol::LIST_WALK, - top_level_function( - vec![ - list_type(flex(TVAR1)), - closure(vec![flex(TVAR1), flex(TVAR2)], TVAR3, Box::new(flex(TVAR2))), - flex(TVAR2), - ], - Box::new(flex(TVAR2)), - ), + vec![ + list_type(flex(TVAR1)), + closure(vec![flex(TVAR1), flex(TVAR2)], TVAR3, Box::new(flex(TVAR2))), + flex(TVAR2), + ], + Box::new(flex(TVAR2)), ); // walkBackwards : List elem, (elem -> accum -> accum), accum -> accum - add_type( + add_top_level_function_type!( Symbol::LIST_WALK_BACKWARDS, - top_level_function( - vec![ - list_type(flex(TVAR1)), - closure(vec![flex(TVAR1), flex(TVAR2)], TVAR3, Box::new(flex(TVAR2))), - flex(TVAR2), - ], - Box::new(flex(TVAR2)), - ), + vec![ + list_type(flex(TVAR1)), + closure(vec![flex(TVAR1), flex(TVAR2)], TVAR3, Box::new(flex(TVAR2))), + flex(TVAR2), + ], + Box::new(flex(TVAR2)), ); fn until_type(content: SolvedType) -> SolvedType { @@ -782,38 +722,35 @@ pub fn types() -> MutMap { } // walkUntil : List elem, (elem -> accum -> [ Continue accum, Stop accum ]), accum -> accum - add_type( + add_top_level_function_type!( Symbol::LIST_WALK_UNTIL, - top_level_function( - vec![ - list_type(flex(TVAR1)), - closure( - vec![flex(TVAR1), flex(TVAR2)], - TVAR3, - Box::new(until_type(flex(TVAR2))), - ), - flex(TVAR2), - ], - Box::new(flex(TVAR2)), - ), + vec![ + list_type(flex(TVAR1)), + closure( + vec![flex(TVAR1), flex(TVAR2)], + TVAR3, + Box::new(until_type(flex(TVAR2))), + ), + flex(TVAR2), + ], + Box::new(flex(TVAR2)), ); // keepIf : List elem, (elem -> Bool) -> List elem - add_type( + add_top_level_function_type!( Symbol::LIST_KEEP_IF, - top_level_function( - vec![ - list_type(flex(TVAR1)), - closure(vec![flex(TVAR1)], TVAR2, Box::new(bool_type())), - ], - Box::new(list_type(flex(TVAR1))), - ), + vec![ + list_type(flex(TVAR1)), + closure(vec![flex(TVAR1)], TVAR2, Box::new(bool_type())), + ], + Box::new(list_type(flex(TVAR1))), ); // keepOks : List before, (before -> Result after *) -> List after - add_type(Symbol::LIST_KEEP_OKS, { + { let_tvars! { star, cvar, before, after}; - top_level_function( + add_top_level_function_type!( + Symbol::LIST_KEEP_OKS, vec![ list_type(flex(before)), closure( @@ -824,12 +761,14 @@ pub fn types() -> MutMap { ], Box::new(list_type(flex(after))), ) - }); + }; // keepErrs: List before, (before -> Result * after) -> List after - add_type(Symbol::LIST_KEEP_ERRS, { + { let_tvars! { star, cvar, before, after}; - top_level_function( + + add_top_level_function_type!( + Symbol::LIST_KEEP_ERRS, vec![ list_type(flex(before)), closure( @@ -840,44 +779,43 @@ pub fn types() -> MutMap { ], Box::new(list_type(flex(after))), ) - }); + }; // range : Int a, Int a -> List (Int a) - add_type(Symbol::LIST_RANGE, { - top_level_function( - vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], - Box::new(list_type(int_type(flex(TVAR1)))), - ) - }); + add_top_level_function_type!( + Symbol::LIST_RANGE, + vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))], + Box::new(list_type(int_type(flex(TVAR1)))), + ); // map : List before, (before -> after) -> List after - add_type( + add_top_level_function_type!( Symbol::LIST_MAP, - top_level_function( - vec![ - list_type(flex(TVAR1)), - closure(vec![flex(TVAR1)], TVAR3, Box::new(flex(TVAR2))), - ], - Box::new(list_type(flex(TVAR2))), - ), + vec![ + list_type(flex(TVAR1)), + closure(vec![flex(TVAR1)], TVAR3, Box::new(flex(TVAR2))), + ], + Box::new(list_type(flex(TVAR2))), ); // mapWithIndex : List before, (Nat, before -> after) -> List after - add_type(Symbol::LIST_MAP_WITH_INDEX, { + { let_tvars! { cvar, before, after}; - top_level_function( + add_top_level_function_type!( + Symbol::LIST_MAP_WITH_INDEX, vec![ list_type(flex(before)), closure(vec![nat_type(), flex(before)], cvar, Box::new(flex(after))), ], Box::new(list_type(flex(after))), ) - }); + }; // map2 : List a, List b, (a, b -> c) -> List c - add_type(Symbol::LIST_MAP2, { + { let_tvars! {a, b, c, cvar}; - top_level_function( + add_top_level_function_type!( + Symbol::LIST_MAP2, vec![ list_type(flex(a)), list_type(flex(b)), @@ -885,13 +823,14 @@ pub fn types() -> MutMap { ], Box::new(list_type(flex(c))), ) - }); + }; - // map3 : List a, List b, List c, (a, b, c -> d) -> List d - add_type(Symbol::LIST_MAP3, { + { let_tvars! {a, b, c, d, cvar}; - top_level_function( + // map3 : List a, List b, List c, (a, b, c -> d) -> List d + add_top_level_function_type!( + Symbol::LIST_MAP3, vec![ list_type(flex(a)), list_type(flex(b)), @@ -900,114 +839,102 @@ pub fn types() -> MutMap { ], Box::new(list_type(flex(d))), ) - }); + }; // append : List elem, elem -> List elem - add_type( + add_top_level_function_type!( Symbol::LIST_APPEND, - top_level_function( - vec![list_type(flex(TVAR1)), flex(TVAR1)], - Box::new(list_type(flex(TVAR1))), - ), + vec![list_type(flex(TVAR1)), flex(TVAR1)], + Box::new(list_type(flex(TVAR1))), ); // prepend : List elem, elem -> List elem - add_type( + add_top_level_function_type!( Symbol::LIST_PREPEND, - top_level_function( - vec![list_type(flex(TVAR1)), flex(TVAR1)], - Box::new(list_type(flex(TVAR1))), - ), + vec![list_type(flex(TVAR1)), flex(TVAR1)], + Box::new(list_type(flex(TVAR1))), ); // join : List (List elem) -> List elem - add_type( + add_top_level_function_type!( Symbol::LIST_JOIN, - top_level_function( - vec![list_type(list_type(flex(TVAR1)))], - Box::new(list_type(flex(TVAR1))), - ), + vec![list_type(list_type(flex(TVAR1)))], + Box::new(list_type(flex(TVAR1))), ); // single : a -> List a - add_type( + add_top_level_function_type!( Symbol::LIST_SINGLE, - top_level_function(vec![flex(TVAR1)], Box::new(list_type(flex(TVAR1)))), + vec![flex(TVAR1)], + Box::new(list_type(flex(TVAR1))) ); // repeat : Nat, elem -> List elem - add_type( + add_top_level_function_type!( Symbol::LIST_REPEAT, - top_level_function( - vec![nat_type(), flex(TVAR1)], - Box::new(list_type(flex(TVAR1))), - ), + vec![nat_type(), flex(TVAR1)], + Box::new(list_type(flex(TVAR1))), ); // reverse : List elem -> List elem - add_type( + add_top_level_function_type!( Symbol::LIST_REVERSE, - top_level_function( - vec![list_type(flex(TVAR1))], - Box::new(list_type(flex(TVAR1))), - ), + vec![list_type(flex(TVAR1))], + Box::new(list_type(flex(TVAR1))), ); // len : List * -> Nat - add_type( + add_top_level_function_type!( Symbol::LIST_LEN, - top_level_function(vec![list_type(flex(TVAR1))], Box::new(nat_type())), + vec![list_type(flex(TVAR1))], + Box::new(nat_type()) ); // isEmpty : List * -> Bool - add_type( + add_top_level_function_type!( Symbol::LIST_IS_EMPTY, - top_level_function(vec![list_type(flex(TVAR1))], Box::new(bool_type())), + vec![list_type(flex(TVAR1))], + Box::new(bool_type()) ); // sortWith : List a, (a, a -> Ordering) -> List a - add_type( + add_top_level_function_type!( Symbol::LIST_SORT_WITH, - top_level_function( - vec![ - list_type(flex(TVAR1)), - closure( - vec![flex(TVAR1), flex(TVAR1)], - TVAR2, - Box::new(ordering_type()), - ), - ], - Box::new(list_type(flex(TVAR1))), - ), + vec![ + list_type(flex(TVAR1)), + closure( + vec![flex(TVAR1), flex(TVAR1)], + TVAR2, + Box::new(ordering_type()), + ), + ], + Box::new(list_type(flex(TVAR1))), ); // Dict module // Dict.hashTestOnly : Nat, v -> Nat - add_type( + add_top_level_function_type!( Symbol::DICT_TEST_HASH, - top_level_function(vec![u64_type(), flex(TVAR2)], Box::new(nat_type())), + vec![u64_type(), flex(TVAR2)], + Box::new(nat_type()) ); // len : Dict * * -> Nat - add_type( + add_top_level_function_type!( Symbol::DICT_LEN, - top_level_function( - vec![dict_type(flex(TVAR1), flex(TVAR2))], - Box::new(nat_type()), - ), + vec![dict_type(flex(TVAR1), flex(TVAR2))], + Box::new(nat_type()), ); // empty : Dict * * - add_type(Symbol::DICT_EMPTY, dict_type(flex(TVAR1), flex(TVAR2))); + add_type!(Symbol::DICT_EMPTY, dict_type(flex(TVAR1), flex(TVAR2))); // single : k, v -> Dict k v - add_type( + add_top_level_function_type!( Symbol::DICT_SINGLE, - top_level_function( - vec![flex(TVAR1), flex(TVAR2)], - Box::new(dict_type(flex(TVAR1), flex(TVAR2))), - ), + vec![flex(TVAR1), flex(TVAR2)], + Box::new(dict_type(flex(TVAR1), flex(TVAR2))), ); // get : Dict k v, k -> Result v [ KeyNotFound ]* @@ -1016,264 +943,220 @@ pub fn types() -> MutMap { Box::new(SolvedType::Wildcard), ); - add_type( + add_top_level_function_type!( Symbol::DICT_GET, - top_level_function( - vec![dict_type(flex(TVAR1), flex(TVAR2)), flex(TVAR1)], - Box::new(result_type(flex(TVAR2), key_not_found)), - ), + vec![dict_type(flex(TVAR1), flex(TVAR2)), flex(TVAR1)], + Box::new(result_type(flex(TVAR2), key_not_found)), ); // Dict.insert : Dict k v, k, v -> Dict k v - add_type( + add_top_level_function_type!( Symbol::DICT_INSERT, - top_level_function( - vec![ - dict_type(flex(TVAR1), flex(TVAR2)), - flex(TVAR1), - flex(TVAR2), - ], - Box::new(dict_type(flex(TVAR1), flex(TVAR2))), - ), + vec![ + dict_type(flex(TVAR1), flex(TVAR2)), + flex(TVAR1), + flex(TVAR2), + ], + Box::new(dict_type(flex(TVAR1), flex(TVAR2))), ); // Dict.remove : Dict k v, k -> Dict k v - add_type( + add_top_level_function_type!( Symbol::DICT_REMOVE, - top_level_function( - vec![dict_type(flex(TVAR1), flex(TVAR2)), flex(TVAR1)], - Box::new(dict_type(flex(TVAR1), flex(TVAR2))), - ), + vec![dict_type(flex(TVAR1), flex(TVAR2)), flex(TVAR1)], + Box::new(dict_type(flex(TVAR1), flex(TVAR2))), ); // Dict.contains : Dict k v, k -> Bool - add_type( + add_top_level_function_type!( Symbol::DICT_CONTAINS, - top_level_function( - vec![dict_type(flex(TVAR1), flex(TVAR2)), flex(TVAR1)], - Box::new(bool_type()), - ), + vec![dict_type(flex(TVAR1), flex(TVAR2)), flex(TVAR1)], + Box::new(bool_type()), ); // Dict.keys : Dict k v -> List k - add_type( + add_top_level_function_type!( Symbol::DICT_KEYS, - top_level_function( - vec![dict_type(flex(TVAR1), flex(TVAR2))], - Box::new(list_type(flex(TVAR1))), - ), + vec![dict_type(flex(TVAR1), flex(TVAR2))], + Box::new(list_type(flex(TVAR1))), ); // Dict.values : Dict k v -> List v - add_type( + add_top_level_function_type!( Symbol::DICT_VALUES, - top_level_function( - vec![dict_type(flex(TVAR1), flex(TVAR2))], - Box::new(list_type(flex(TVAR2))), - ), + vec![dict_type(flex(TVAR1), flex(TVAR2))], + Box::new(list_type(flex(TVAR2))), ); // Dict.union : Dict k v, Dict k v -> Dict k v - add_type( + add_top_level_function_type!( Symbol::DICT_UNION, - top_level_function( - vec![ - dict_type(flex(TVAR1), flex(TVAR2)), - dict_type(flex(TVAR1), flex(TVAR2)), - ], - Box::new(dict_type(flex(TVAR1), flex(TVAR2))), - ), + vec![ + dict_type(flex(TVAR1), flex(TVAR2)), + dict_type(flex(TVAR1), flex(TVAR2)), + ], + Box::new(dict_type(flex(TVAR1), flex(TVAR2))), ); // Dict.intersection : Dict k v, Dict k v -> Dict k v - add_type( + add_top_level_function_type!( Symbol::DICT_INTERSECTION, - top_level_function( - vec![ - dict_type(flex(TVAR1), flex(TVAR2)), - dict_type(flex(TVAR1), flex(TVAR2)), - ], - Box::new(dict_type(flex(TVAR1), flex(TVAR2))), - ), + vec![ + dict_type(flex(TVAR1), flex(TVAR2)), + dict_type(flex(TVAR1), flex(TVAR2)), + ], + Box::new(dict_type(flex(TVAR1), flex(TVAR2))), ); // Dict.difference : Dict k v, Dict k v -> Dict k v - add_type( + add_top_level_function_type!( Symbol::DICT_DIFFERENCE, - top_level_function( - vec![ - dict_type(flex(TVAR1), flex(TVAR2)), - dict_type(flex(TVAR1), flex(TVAR2)), - ], - Box::new(dict_type(flex(TVAR1), flex(TVAR2))), - ), + vec![ + dict_type(flex(TVAR1), flex(TVAR2)), + dict_type(flex(TVAR1), flex(TVAR2)), + ], + Box::new(dict_type(flex(TVAR1), flex(TVAR2))), ); // Dict.walk : Dict k v, (k, v, accum -> accum), accum -> accum - add_type( + add_top_level_function_type!( Symbol::DICT_WALK, - top_level_function( - vec![ - dict_type(flex(TVAR1), flex(TVAR2)), - closure( - vec![flex(TVAR1), flex(TVAR2), flex(TVAR3)], - TVAR4, - Box::new(flex(TVAR3)), - ), - flex(TVAR3), - ], - Box::new(flex(TVAR3)), - ), + vec![ + dict_type(flex(TVAR1), flex(TVAR2)), + closure( + vec![flex(TVAR1), flex(TVAR2), flex(TVAR3)], + TVAR4, + Box::new(flex(TVAR3)), + ), + flex(TVAR3), + ], + Box::new(flex(TVAR3)), ); // Set module // empty : Set a - add_type(Symbol::SET_EMPTY, set_type(flex(TVAR1))); + add_type!(Symbol::SET_EMPTY, set_type(flex(TVAR1))); // single : a -> Set a - add_type( + add_top_level_function_type!( Symbol::SET_SINGLE, - top_level_function(vec![flex(TVAR1)], Box::new(set_type(flex(TVAR1)))), + vec![flex(TVAR1)], + Box::new(set_type(flex(TVAR1))) ); // len : Set * -> Nat - add_type( + add_top_level_function_type!( Symbol::SET_LEN, - top_level_function(vec![set_type(flex(TVAR1))], Box::new(nat_type())), + vec![set_type(flex(TVAR1))], + Box::new(nat_type()) ); // toList : Set a -> List a - add_type( + add_top_level_function_type!( Symbol::SET_TO_LIST, - top_level_function( - vec![set_type(flex(TVAR1))], - Box::new(list_type(flex(TVAR1))), - ), + vec![set_type(flex(TVAR1))], + Box::new(list_type(flex(TVAR1))), ); - // fromList : Set a -> List a - add_type( + // fromList : List a -> Set a + add_top_level_function_type!( Symbol::SET_FROM_LIST, - top_level_function( - vec![list_type(flex(TVAR1))], - Box::new(set_type(flex(TVAR1))), - ), + vec![list_type(flex(TVAR1))], + Box::new(set_type(flex(TVAR1))), ); // union : Set a, Set a -> Set a - add_type( + add_top_level_function_type!( Symbol::SET_UNION, - top_level_function( - vec![set_type(flex(TVAR1)), set_type(flex(TVAR1))], - Box::new(set_type(flex(TVAR1))), - ), + vec![set_type(flex(TVAR1)), set_type(flex(TVAR1))], + Box::new(set_type(flex(TVAR1))), ); // difference : Set a, Set a -> Set a - add_type( + add_top_level_function_type!( Symbol::SET_DIFFERENCE, - top_level_function( - vec![set_type(flex(TVAR1)), set_type(flex(TVAR1))], - Box::new(set_type(flex(TVAR1))), - ), + vec![set_type(flex(TVAR1)), set_type(flex(TVAR1))], + Box::new(set_type(flex(TVAR1))), ); // intersection : Set a, Set a -> Set a - add_type( + add_top_level_function_type!( Symbol::SET_INTERSECTION, - top_level_function( - vec![set_type(flex(TVAR1)), set_type(flex(TVAR1))], - Box::new(set_type(flex(TVAR1))), - ), + vec![set_type(flex(TVAR1)), set_type(flex(TVAR1))], + Box::new(set_type(flex(TVAR1))), ); // Set.walk : Set a, (a, b -> b), b -> b - add_type( + add_top_level_function_type!( Symbol::SET_WALK, - top_level_function( - vec![ - set_type(flex(TVAR1)), - closure(vec![flex(TVAR1), flex(TVAR2)], TVAR3, Box::new(flex(TVAR2))), - flex(TVAR2), - ], - Box::new(flex(TVAR2)), - ), + vec![ + set_type(flex(TVAR1)), + closure(vec![flex(TVAR1), flex(TVAR2)], TVAR3, Box::new(flex(TVAR2))), + flex(TVAR2), + ], + Box::new(flex(TVAR2)), ); - add_type( + add_top_level_function_type!( Symbol::SET_INSERT, - top_level_function( - vec![set_type(flex(TVAR1)), flex(TVAR1)], - Box::new(set_type(flex(TVAR1))), - ), + vec![set_type(flex(TVAR1)), flex(TVAR1)], + Box::new(set_type(flex(TVAR1))), ); - add_type( + add_top_level_function_type!( Symbol::SET_REMOVE, - top_level_function( - vec![set_type(flex(TVAR1)), flex(TVAR1)], - Box::new(set_type(flex(TVAR1))), - ), + vec![set_type(flex(TVAR1)), flex(TVAR1)], + Box::new(set_type(flex(TVAR1))), ); - add_type( + add_top_level_function_type!( Symbol::SET_CONTAINS, - top_level_function( - vec![set_type(flex(TVAR1)), flex(TVAR1)], - Box::new(bool_type()), - ), + vec![set_type(flex(TVAR1)), flex(TVAR1)], + Box::new(bool_type()), ); // Result module // map : Result a err, (a -> b) -> Result b err - add_type( + add_top_level_function_type!( Symbol::RESULT_MAP, - top_level_function( - vec![ - result_type(flex(TVAR1), flex(TVAR3)), - closure(vec![flex(TVAR1)], TVAR4, Box::new(flex(TVAR2))), - ], - Box::new(result_type(flex(TVAR2), flex(TVAR3))), - ), + vec![ + result_type(flex(TVAR1), flex(TVAR3)), + closure(vec![flex(TVAR1)], TVAR4, Box::new(flex(TVAR2))), + ], + Box::new(result_type(flex(TVAR2), flex(TVAR3))), ); // mapErr : Result a x, (x -> y) -> Result a x - add_type( + add_top_level_function_type!( Symbol::RESULT_MAP_ERR, - top_level_function( - vec![ - result_type(flex(TVAR1), flex(TVAR3)), - closure(vec![flex(TVAR3)], TVAR4, Box::new(flex(TVAR2))), - ], - Box::new(result_type(flex(TVAR1), flex(TVAR2))), - ), + vec![ + result_type(flex(TVAR1), flex(TVAR3)), + closure(vec![flex(TVAR3)], TVAR4, Box::new(flex(TVAR2))), + ], + Box::new(result_type(flex(TVAR1), flex(TVAR2))), ); // after : Result a err, (a -> Result b err) -> Result b err - add_type( + add_top_level_function_type!( Symbol::RESULT_AFTER, - top_level_function( - vec![ - result_type(flex(TVAR1), flex(TVAR3)), - closure( - vec![flex(TVAR1)], - TVAR4, - Box::new(result_type(flex(TVAR2), flex(TVAR3))), - ), - ], - Box::new(result_type(flex(TVAR2), flex(TVAR3))), - ), + vec![ + result_type(flex(TVAR1), flex(TVAR3)), + closure( + vec![flex(TVAR1)], + TVAR4, + Box::new(result_type(flex(TVAR2), flex(TVAR3))), + ), + ], + Box::new(result_type(flex(TVAR2), flex(TVAR3))), ); // withDefault : Result a x, a -> a - add_type( + add_top_level_function_type!( Symbol::RESULT_WITH_DEFAULT, - top_level_function( - vec![result_type(flex(TVAR1), flex(TVAR3)), flex(TVAR1)], - Box::new(flex(TVAR1)), - ), + vec![result_type(flex(TVAR1), flex(TVAR3)), flex(TVAR1)], + Box::new(flex(TVAR1)), ); types @@ -1284,15 +1167,6 @@ fn flex(tvar: VarId) -> SolvedType { SolvedType::Flex(tvar) } -#[inline(always)] -fn top_level_function(arguments: Vec, ret: Box) -> SolvedType { - SolvedType::Func( - arguments, - Box::new(SolvedType::Flex(TOP_LEVEL_CLOSURE_VAR)), - ret, - ) -} - #[inline(always)] fn closure(arguments: Vec, closure_var: VarId, ret: Box) -> SolvedType { SolvedType::Func(arguments, Box::new(SolvedType::Flex(closure_var)), ret) diff --git a/compiler/gen/src/llvm/bitcode.rs b/compiler/gen/src/llvm/bitcode.rs index b80f51375e..8545e8770a 100644 --- a/compiler/gen/src/llvm/bitcode.rs +++ b/compiler/gen/src/llvm/bitcode.rs @@ -2,7 +2,9 @@ use crate::debug_info_init; use crate::llvm::build::{set_name, Env, C_CALL_CONV, FAST_CALL_CONV}; use crate::llvm::convert::basic_type_from_layout; -use crate::llvm::refcounting::{decrement_refcount_layout, increment_refcount_layout, Mode}; +use crate::llvm::refcounting::{ + decrement_refcount_layout, increment_n_refcount_layout, increment_refcount_layout, +}; use inkwell::attributes::{Attribute, AttributeLoc}; use inkwell::types::{BasicType, BasicTypeEnum}; use inkwell::values::{BasicValueEnum, CallSiteValue, FunctionValue, InstructionValue}; @@ -204,21 +206,28 @@ fn build_transform_caller_help<'a, 'ctx, 'env>( function_value } +enum Mode { + Inc, + IncN, + Dec, +} + +/// a functin that accepts two arguments: the value to increment, and an amount to increment by pub fn build_inc_n_wrapper<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, layout: &Layout<'a>, - n: u64, ) -> FunctionValue<'ctx> { - build_rc_wrapper(env, layout_ids, layout, Mode::Inc(n)) + build_rc_wrapper(env, layout_ids, layout, Mode::IncN) } +/// a functin that accepts two arguments: the value to increment; increments by 1 pub fn build_inc_wrapper<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, layout: &Layout<'a>, ) -> FunctionValue<'ctx> { - build_rc_wrapper(env, layout_ids, layout, Mode::Inc(1)) + build_rc_wrapper(env, layout_ids, layout, Mode::Inc) } pub fn build_dec_wrapper<'a, 'ctx, 'env>( @@ -229,7 +238,7 @@ pub fn build_dec_wrapper<'a, 'ctx, 'env>( build_rc_wrapper(env, layout_ids, layout, Mode::Dec) } -pub fn build_rc_wrapper<'a, 'ctx, 'env>( +fn build_rc_wrapper<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, layout: &Layout<'a>, @@ -244,7 +253,8 @@ pub fn build_rc_wrapper<'a, 'ctx, 'env>( .to_symbol_string(symbol, &env.interns); let fn_name = match rc_operation { - Mode::Inc(n) => format!("{}_inc_{}", fn_name, n), + Mode::IncN => format!("{}_inc_n", fn_name), + Mode::Inc => format!("{}_inc", fn_name), Mode::Dec => format!("{}_dec", fn_name), }; @@ -253,12 +263,20 @@ pub fn build_rc_wrapper<'a, 'ctx, 'env>( None => { let arg_type = env.context.i8_type().ptr_type(AddressSpace::Generic); - let function_value = crate::llvm::refcounting::build_header_help( - env, - &fn_name, - env.context.void_type().into(), - &[arg_type.into()], - ); + let function_value = match rc_operation { + Mode::Inc | Mode::Dec => crate::llvm::refcounting::build_header_help( + env, + &fn_name, + env.context.void_type().into(), + &[arg_type.into()], + ), + Mode::IncN => crate::llvm::refcounting::build_header_help( + env, + &fn_name, + env.context.void_type().into(), + &[arg_type.into(), env.ptr_int().into()], + ), + }; let kind_id = Attribute::get_named_enum_kind_id("alwaysinline"); debug_assert!(kind_id > 0); @@ -285,9 +303,16 @@ pub fn build_rc_wrapper<'a, 'ctx, 'env>( let value = env.builder.build_load(value_cast, "load_opaque"); match rc_operation { - Mode::Inc(n) => { + Mode::Inc => { + let n = 1; increment_refcount_layout(env, function_value, layout_ids, n, value, layout); } + Mode::IncN => { + let n = it.next().unwrap().into_int_value(); + set_name(n.into(), Symbol::ARG_2.ident_string(&env.interns)); + + increment_n_refcount_layout(env, function_value, layout_ids, n, value, layout); + } Mode::Dec => { decrement_refcount_layout(env, function_value, layout_ids, value, layout); } diff --git a/compiler/gen/src/llvm/build_dict.rs b/compiler/gen/src/llvm/build_dict.rs index 7ef54fd055..b282f77ddd 100644 --- a/compiler/gen/src/llvm/build_dict.rs +++ b/compiler/gen/src/llvm/build_dict.rs @@ -397,9 +397,16 @@ pub fn dict_elements_rc<'a, 'ctx, 'env>( let alignment = Alignment::from_key_value_layout(key_layout, value_layout, env.ptr_bytes); let alignment_iv = env.context.i8_type().const_int(alignment as u64, false); - use crate::llvm::bitcode::build_rc_wrapper; - let inc_key_fn = build_rc_wrapper(env, layout_ids, key_layout, rc_operation); - let inc_value_fn = build_rc_wrapper(env, layout_ids, value_layout, rc_operation); + let (key_fn, value_fn) = match rc_operation { + Mode::Inc => ( + build_inc_wrapper(env, layout_ids, key_layout), + build_inc_wrapper(env, layout_ids, value_layout), + ), + Mode::Dec => ( + build_dec_wrapper(env, layout_ids, key_layout), + build_dec_wrapper(env, layout_ids, value_layout), + ), + }; call_void_bitcode_fn( env, @@ -408,8 +415,8 @@ pub fn dict_elements_rc<'a, 'ctx, 'env>( alignment_iv.into(), key_width.into(), value_width.into(), - inc_key_fn.as_global_value().as_pointer_value().into(), - inc_value_fn.as_global_value().as_pointer_value().into(), + key_fn.as_global_value().as_pointer_value().into(), + value_fn.as_global_value().as_pointer_value().into(), ], &bitcode::DICT_ELEMENTS_RC, ); diff --git a/compiler/gen/src/llvm/build_list.rs b/compiler/gen/src/llvm/build_list.rs index b3a759e641..6634c51862 100644 --- a/compiler/gen/src/llvm/build_list.rs +++ b/compiler/gen/src/llvm/build_list.rs @@ -1,7 +1,7 @@ #![allow(clippy::too_many_arguments)] use crate::llvm::bitcode::{ - build_compare_wrapper, build_dec_wrapper, build_eq_wrapper, build_inc_wrapper, - build_transform_caller, call_bitcode_fn, call_void_bitcode_fn, + build_compare_wrapper, build_dec_wrapper, build_eq_wrapper, build_inc_n_wrapper, + build_inc_wrapper, build_transform_caller, call_bitcode_fn, call_void_bitcode_fn, }; use crate::llvm::build::{ allocate_with_refcount_help, cast_basic_basic, complex_bitcast, Env, InPlace, @@ -118,7 +118,7 @@ pub fn list_repeat<'a, 'ctx, 'env>( element: BasicValueEnum<'ctx>, element_layout: &Layout<'a>, ) -> BasicValueEnum<'ctx> { - let inc_element_fn = build_inc_wrapper(env, layout_ids, element_layout); + let inc_element_fn = build_inc_n_wrapper(env, layout_ids, element_layout); call_bitcode_fn_returns_list( env, diff --git a/compiler/gen/src/llvm/refcounting.rs b/compiler/gen/src/llvm/refcounting.rs index ad6c1d24de..a03398322e 100644 --- a/compiler/gen/src/llvm/refcounting.rs +++ b/compiler/gen/src/llvm/refcounting.rs @@ -109,7 +109,7 @@ impl<'ctx> PointerToRefcount<'ctx> { env: &Env<'a, 'ctx, 'env>, ) { match mode { - CallMode::Inc(_, inc_amount) => self.increment(inc_amount, env), + CallMode::Inc(inc_amount) => self.increment(inc_amount, env), CallMode::Dec => self.decrement(env, layout), } } @@ -315,14 +315,85 @@ impl<'ctx> PointerToRefcount<'ctx> { fn modify_refcount_struct<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, - parent: FunctionValue<'ctx>, layout_ids: &mut LayoutIds<'a>, - value: BasicValueEnum<'ctx>, - layouts: &[Layout<'a>], + layouts: &'a [Layout<'a>], mode: Mode, when_recursive: &WhenRecursive<'a>, +) -> FunctionValue<'ctx> { + let block = env.builder.get_insert_block().expect("to be in a function"); + let di_location = env.builder.get_current_debug_location().unwrap(); + + let layout = Layout::Struct(layouts); + + let (_, fn_name) = function_name_from_mode( + layout_ids, + &env.interns, + "increment_struct", + "decrement_struct", + &layout, + mode, + ); + + let function = match env.module.get_function(fn_name.as_str()) { + Some(function_value) => function_value, + None => { + let basic_type = basic_type_from_layout(env, &layout); + let function_value = build_header(env, basic_type, mode, &fn_name); + + modify_refcount_struct_help( + env, + layout_ids, + mode, + when_recursive, + layouts, + function_value, + ); + + function_value + } + }; + + env.builder.position_at_end(block); + env.builder + .set_current_debug_location(env.context, di_location); + + function +} + +#[allow(clippy::too_many_arguments)] +fn modify_refcount_struct_help<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + mode: Mode, + when_recursive: &WhenRecursive<'a>, + layouts: &[Layout<'a>], + fn_val: FunctionValue<'ctx>, ) { - let wrapper_struct = value.into_struct_value(); + debug_assert_eq!( + when_recursive, + &WhenRecursive::Unreachable, + "TODO pipe when_recursive through the dict key/value inc/dec" + ); + + let builder = env.builder; + let ctx = env.context; + + // Add a basic block for the entry point + let entry = ctx.append_basic_block(fn_val, "entry"); + + builder.position_at_end(entry); + + debug_info_init!(env, fn_val); + + // Add args to scope + let arg_symbol = Symbol::ARG_1; + let arg_val = fn_val.get_param_iter().next().unwrap(); + + set_name(arg_val, arg_symbol.ident_string(&env.interns)); + + let parent = fn_val; + + let wrapper_struct = arg_val.into_struct_value(); for (i, field_layout) in layouts.iter().enumerate() { if field_layout.contains_refcounted() { @@ -335,13 +406,15 @@ fn modify_refcount_struct<'a, 'ctx, 'env>( env, parent, layout_ids, - mode, + mode.to_call_mode(fn_val), when_recursive, field_ptr, field_layout, ); } } + // this function returns void + builder.build_return(None); } pub fn increment_refcount_layout<'a, 'ctx, 'env>( @@ -351,12 +424,24 @@ pub fn increment_refcount_layout<'a, 'ctx, 'env>( inc_amount: u64, value: BasicValueEnum<'ctx>, layout: &Layout<'a>, +) { + let amount = env.ptr_int().const_int(inc_amount, false); + increment_n_refcount_layout(env, parent, layout_ids, amount, value, layout); +} + +pub fn increment_n_refcount_layout<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + parent: FunctionValue<'ctx>, + layout_ids: &mut LayoutIds<'a>, + amount: IntValue<'ctx>, + value: BasicValueEnum<'ctx>, + layout: &Layout<'a>, ) { modify_refcount_layout( env, parent, layout_ids, - Mode::Inc(inc_amount), + CallMode::Inc(amount), value, layout, ); @@ -369,7 +454,7 @@ pub fn decrement_refcount_layout<'a, 'ctx, 'env>( value: BasicValueEnum<'ctx>, layout: &Layout<'a>, ) { - modify_refcount_layout(env, parent, layout_ids, Mode::Dec, value, layout); + modify_refcount_layout(env, parent, layout_ids, CallMode::Dec, value, layout); } fn modify_refcount_builtin<'a, 'ctx, 'env>( @@ -377,37 +462,33 @@ fn modify_refcount_builtin<'a, 'ctx, 'env>( layout_ids: &mut LayoutIds<'a>, mode: Mode, when_recursive: &WhenRecursive<'a>, - value: BasicValueEnum<'ctx>, layout: &Layout<'a>, builtin: &Builtin<'a>, -) { +) -> Option> { use Builtin::*; match builtin { List(memory_mode, element_layout) => { - let wrapper_struct = value.into_struct_value(); - if let MemoryMode::Refcounted = memory_mode { - modify_refcount_list( + let function = modify_refcount_list( env, layout_ids, mode, when_recursive, layout, element_layout, - wrapper_struct, ); + + Some(function) + } else { + None } } Set(element_layout) => { - if element_layout.contains_refcounted() { - // TODO decrement all values - } - todo!(); - } - Dict(key_layout, value_layout) => { - let wrapper_struct = value.into_struct_value(); - modify_refcount_dict( + let key_layout = &Layout::Struct(&[]); + let value_layout = element_layout; + + let function = modify_refcount_dict( env, layout_ids, mode, @@ -415,16 +496,29 @@ fn modify_refcount_builtin<'a, 'ctx, 'env>( layout, key_layout, value_layout, - wrapper_struct, ); + + Some(function) + } + Dict(key_layout, value_layout) => { + let function = modify_refcount_dict( + env, + layout_ids, + mode, + when_recursive, + layout, + key_layout, + value_layout, + ); + + Some(function) } - Str => { - let wrapper_struct = value.into_struct_value(); - modify_refcount_str(env, layout_ids, mode, layout, wrapper_struct); - } + Str => Some(modify_refcount_str(env, layout_ids, mode, layout)), + _ => { debug_assert!(!builtin.is_refcounted()); + None } } } @@ -433,7 +527,7 @@ fn modify_refcount_layout<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, parent: FunctionValue<'ctx>, layout_ids: &mut LayoutIds<'a>, - mode: Mode, + call_mode: CallMode<'ctx>, value: BasicValueEnum<'ctx>, layout: &Layout<'a>, ) { @@ -441,7 +535,7 @@ fn modify_refcount_layout<'a, 'ctx, 'env>( env, parent, layout_ids, - mode, + call_mode, &WhenRecursive::Unreachable, value, layout, @@ -458,127 +552,29 @@ fn modify_refcount_layout_help<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, parent: FunctionValue<'ctx>, layout_ids: &mut LayoutIds<'a>, - mode: Mode, + call_mode: CallMode<'ctx>, when_recursive: &WhenRecursive<'a>, value: BasicValueEnum<'ctx>, layout: &Layout<'a>, ) { - use Layout::*; + let mode = match call_mode { + CallMode::Inc(_) => Mode::Inc, + CallMode::Dec => Mode::Dec, + }; + + let function = match modify_refcount_layout_build_function( + env, + parent, + layout_ids, + mode, + when_recursive, + layout, + ) { + Some(f) => f, + None => return, + }; match layout { - Builtin(builtin) => modify_refcount_builtin( - env, - layout_ids, - mode, - when_recursive, - value, - layout, - builtin, - ), - - Union(variant) => { - use UnionLayout::*; - - match variant { - NullableWrapped { - other_tags: tags, .. - } => { - debug_assert!(value.is_pointer_value()); - - build_rec_union( - env, - layout_ids, - mode, - &WhenRecursive::Loop(*variant), - tags, - value.into_pointer_value(), - true, - ); - } - - NullableUnwrapped { other_fields, .. } => { - debug_assert!(value.is_pointer_value()); - - let other_fields = &other_fields[1..]; - - build_rec_union( - env, - layout_ids, - mode, - &WhenRecursive::Loop(*variant), - &*env.arena.alloc([other_fields]), - value.into_pointer_value(), - true, - ); - } - - NonNullableUnwrapped(fields) => { - debug_assert!(value.is_pointer_value()); - - build_rec_union( - env, - layout_ids, - mode, - &WhenRecursive::Loop(*variant), - &*env.arena.alloc([*fields]), - value.into_pointer_value(), - true, - ); - } - - Recursive(tags) => { - debug_assert!(value.is_pointer_value()); - build_rec_union( - env, - layout_ids, - mode, - &WhenRecursive::Loop(*variant), - tags, - value.into_pointer_value(), - false, - ); - } - - NonRecursive(tags) => { - modify_refcount_union(env, layout_ids, mode, when_recursive, tags, value) - } - } - } - Closure(_, closure_layout, _) => { - if closure_layout.contains_refcounted() { - let wrapper_struct = value.into_struct_value(); - - let field_ptr = env - .builder - .build_extract_value(wrapper_struct, 1, "modify_rc_closure_data") - .unwrap(); - - modify_refcount_layout_help( - env, - parent, - layout_ids, - mode, - when_recursive, - field_ptr, - &closure_layout.as_block_of_memory_layout(), - ) - } - } - - Struct(layouts) => { - modify_refcount_struct( - env, - parent, - layout_ids, - value, - layouts, - mode, - when_recursive, - ); - } - - PhantomEmptyStruct => {} - Layout::RecursivePointer => match when_recursive { WhenRecursive::Unreachable => { unreachable!("recursion pointers should never be hashed directly") @@ -594,19 +590,172 @@ fn modify_refcount_layout_help<'a, 'ctx, 'env>( .build_bitcast(value, bt, "i64_to_opaque") .into_pointer_value(); - modify_refcount_layout_help( + call_help(env, function, call_mode, field_cast.into()); + } + }, + _ => { + call_help(env, function, call_mode, value); + } + } +} + +fn call_help<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + function: FunctionValue<'ctx>, + call_mode: CallMode<'ctx>, + value: BasicValueEnum<'ctx>, +) -> inkwell::values::CallSiteValue<'ctx> { + let call = match call_mode { + CallMode::Inc(inc_amount) => { + env.builder + .build_call(function, &[value, inc_amount.into()], "increment") + } + CallMode::Dec => env.builder.build_call(function, &[value], "decrement"), + }; + + call.set_call_convention(FAST_CALL_CONV); + + call +} + +fn modify_refcount_layout_build_function<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + parent: FunctionValue<'ctx>, + layout_ids: &mut LayoutIds<'a>, + mode: Mode, + when_recursive: &WhenRecursive<'a>, + layout: &Layout<'a>, +) -> Option> { + use Layout::*; + + match layout { + Builtin(builtin) => { + modify_refcount_builtin(env, layout_ids, mode, when_recursive, layout, builtin) + } + + Union(variant) => { + use UnionLayout::*; + + match variant { + NullableWrapped { + other_tags: tags, .. + } => { + let function = build_rec_union( + env, + layout_ids, + mode, + &WhenRecursive::Loop(*variant), + tags, + true, + ); + + Some(function) + } + + NullableUnwrapped { other_fields, .. } => { + let other_fields = &other_fields[1..]; + + let function = build_rec_union( + env, + layout_ids, + mode, + &WhenRecursive::Loop(*variant), + &*env.arena.alloc([other_fields]), + true, + ); + + Some(function) + } + + NonNullableUnwrapped(fields) => { + let function = build_rec_union( + env, + layout_ids, + mode, + &WhenRecursive::Loop(*variant), + &*env.arena.alloc([*fields]), + true, + ); + Some(function) + } + + Recursive(tags) => { + let function = build_rec_union( + env, + layout_ids, + mode, + &WhenRecursive::Loop(*variant), + tags, + false, + ); + Some(function) + } + + NonRecursive(tags) => { + let function = + modify_refcount_union(env, layout_ids, mode, when_recursive, tags); + + Some(function) + } + } + } + Closure(argument_layouts, closure_layout, return_layout) => { + if closure_layout.contains_refcounted() { + // Temporary hack to make this work for now. With defunctionalization, none of this + // will matter + let p2 = closure_layout.as_block_of_memory_layout(); + let mut argument_layouts = + Vec::from_iter_in(argument_layouts.iter().copied(), env.arena); + argument_layouts.push(p2); + let argument_layouts = argument_layouts.into_bump_slice(); + + let p1 = Layout::FunctionPointer(argument_layouts, return_layout); + let actual_layout = Layout::Struct(env.arena.alloc([p1, p2])); + + let function = modify_refcount_layout_build_function( + env, + parent, + layout_ids, + mode, + when_recursive, + &actual_layout, + )?; + + Some(function) + } else { + None + } + } + + Struct(layouts) => { + let function = modify_refcount_struct(env, layout_ids, layouts, mode, when_recursive); + + Some(function) + } + + PhantomEmptyStruct => None, + + Layout::RecursivePointer => match when_recursive { + WhenRecursive::Unreachable => { + unreachable!("recursion pointers should never be hashed directly") + } + WhenRecursive::Loop(union_layout) => { + let layout = Layout::Union(*union_layout); + + let function = modify_refcount_layout_build_function( env, parent, layout_ids, mode, when_recursive, - field_cast.into(), &layout, - ) + )?; + + Some(function) } }, - FunctionPointer(_, _) | Pointer(_) => {} + FunctionPointer(_, _) | Pointer(_) => None, } } @@ -617,12 +766,11 @@ fn modify_refcount_list<'a, 'ctx, 'env>( when_recursive: &WhenRecursive<'a>, layout: &Layout<'a>, element_layout: &Layout<'a>, - original_wrapper: StructValue<'ctx>, -) { +) -> FunctionValue<'ctx> { let block = env.builder.get_insert_block().expect("to be in a function"); let di_location = env.builder.get_current_debug_location().unwrap(); - let (call_name, fn_name) = function_name_from_mode( + let (_, fn_name) = function_name_from_mode( layout_ids, &env.interns, "increment_list", @@ -655,13 +803,13 @@ fn modify_refcount_list<'a, 'ctx, 'env>( env.builder .set_current_debug_location(env.context, di_location); - call_help(env, function, mode, original_wrapper.into(), call_name); + function } fn mode_to_call_mode(function: FunctionValue<'_>, mode: Mode) -> CallMode<'_> { match mode { Mode::Dec => CallMode::Dec, - Mode::Inc(num) => CallMode::Inc(num, function.get_nth_param(1).unwrap().into_int_value()), + Mode::Inc => CallMode::Inc(function.get_nth_param(1).unwrap().into_int_value()), } } @@ -720,7 +868,7 @@ fn modify_refcount_list_help<'a, 'ctx, 'env>( env, parent, layout_ids, - mode, + mode.to_call_mode(fn_val), when_recursive, element, element_layout, @@ -755,12 +903,11 @@ fn modify_refcount_str<'a, 'ctx, 'env>( layout_ids: &mut LayoutIds<'a>, mode: Mode, layout: &Layout<'a>, - original_wrapper: StructValue<'ctx>, -) { +) -> FunctionValue<'ctx> { let block = env.builder.get_insert_block().expect("to be in a function"); let di_location = env.builder.get_current_debug_location().unwrap(); - let (call_name, fn_name) = function_name_from_mode( + let (_, fn_name) = function_name_from_mode( layout_ids, &env.interns, "increment_str", @@ -785,7 +932,7 @@ fn modify_refcount_str<'a, 'ctx, 'env>( env.builder .set_current_debug_location(env.context, di_location); - call_help(env, function, mode, original_wrapper.into(), call_name); + function } fn modify_refcount_str_help<'a, 'ctx, 'env>( @@ -855,12 +1002,11 @@ fn modify_refcount_dict<'a, 'ctx, 'env>( layout: &Layout<'a>, key_layout: &Layout<'a>, value_layout: &Layout<'a>, - original_wrapper: StructValue<'ctx>, -) { +) -> FunctionValue<'ctx> { let block = env.builder.get_insert_block().expect("to be in a function"); let di_location = env.builder.get_current_debug_location().unwrap(); - let (call_name, fn_name) = function_name_from_mode( + let (_, fn_name) = function_name_from_mode( layout_ids, &env.interns, "increment_dict", @@ -894,7 +1040,7 @@ fn modify_refcount_dict<'a, 'ctx, 'env>( env.builder .set_current_debug_location(env.context, di_location); - call_help(env, function, mode, original_wrapper.into(), call_name); + function } #[allow(clippy::too_many_arguments)] @@ -990,7 +1136,7 @@ fn build_header<'a, 'ctx, 'env>( fn_name: &str, ) -> FunctionValue<'ctx> { match mode { - Mode::Inc(_) => build_header_help( + Mode::Inc => build_header_help( env, fn_name, env.context.void_type().into(), @@ -1036,13 +1182,26 @@ pub fn build_header_help<'a, 'ctx, 'env>( #[derive(Clone, Copy)] pub enum Mode { - Inc(u64), + Inc, Dec, } +impl Mode { + fn to_call_mode<'ctx>(&self, function: FunctionValue<'ctx>) -> CallMode<'ctx> { + match self { + Mode::Inc => { + let amount = function.get_nth_param(1).unwrap().into_int_value(); + + CallMode::Inc(amount) + } + Mode::Dec => CallMode::Dec, + } + } +} + #[derive(Clone, Copy)] enum CallMode<'ctx> { - Inc(u64, IntValue<'ctx>), + Inc(IntValue<'ctx>), Dec, } @@ -1052,12 +1211,11 @@ fn build_rec_union<'a, 'ctx, 'env>( mode: Mode, when_recursive: &WhenRecursive<'a>, fields: &'a [&'a [Layout<'a>]], - value: PointerValue<'ctx>, is_nullable: bool, -) { +) -> FunctionValue<'ctx> { let layout = Layout::Union(UnionLayout::Recursive(fields)); - let (call_name, fn_name) = function_name_from_mode( + let (_, fn_name) = function_name_from_mode( layout_ids, &env.interns, "increment_rec_union", @@ -1095,7 +1253,7 @@ fn build_rec_union<'a, 'ctx, 'env>( } }; - call_help(env, function, mode, value.into(), call_name); + function } fn build_rec_union_help<'a, 'ctx, 'env>( @@ -1112,7 +1270,7 @@ fn build_rec_union_help<'a, 'ctx, 'env>( let context = &env.context; let builder = env.builder; - let pick = |a, b| if let Mode::Inc(_) = mode { a } else { b }; + let pick = |a, b| if let Mode::Inc = mode { a } else { b }; // Add a basic block for the entry point let entry = context.append_basic_block(fn_val, "entry"); @@ -1258,17 +1416,16 @@ fn build_rec_union_help<'a, 'ctx, 'env>( env, parent, layout_ids, - mode, + mode.to_call_mode(fn_val), when_recursive, field, field_layout, ); } - let call_name = pick("recursive_tag_increment", "recursive_tag_decrement"); for ptr in deferred_rec { // recursively decrement the field - let call = call_help(env, fn_val, mode, ptr, call_name); + let call = call_help(env, fn_val, mode.to_call_mode(fn_val), ptr); call.set_tail_call(true); } @@ -1331,28 +1488,6 @@ fn rec_union_read_tag<'a, 'ctx, 'env>( .into_int_value() } -fn call_help<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - function: FunctionValue<'ctx>, - mode: Mode, - value: BasicValueEnum<'ctx>, - call_name: &str, -) -> inkwell::values::CallSiteValue<'ctx> { - let call = match mode { - Mode::Inc(inc_amount) => { - let rc_increment = ptr_int(env.context, env.ptr_bytes).const_int(inc_amount, false); - - env.builder - .build_call(function, &[value, rc_increment.into()], call_name) - } - Mode::Dec => env.builder.build_call(function, &[value], call_name), - }; - - call.set_call_convention(FAST_CALL_CONV); - - call -} - fn function_name_from_mode<'a>( layout_ids: &mut LayoutIds<'a>, interns: &Interns, @@ -1368,7 +1503,7 @@ fn function_name_from_mode<'a>( // rather confusing, so now `inc_x` always corresponds to `dec_x` let layout_id = layout_ids.get(Symbol::DEC, layout); match mode { - Mode::Inc(_) => (if_inc, layout_id.to_symbol_string(Symbol::INC, interns)), + Mode::Inc => (if_inc, layout_id.to_symbol_string(Symbol::INC, interns)), Mode::Dec => (if_dec, layout_id.to_symbol_string(Symbol::DEC, interns)), } } @@ -1379,14 +1514,13 @@ fn modify_refcount_union<'a, 'ctx, 'env>( mode: Mode, when_recursive: &WhenRecursive<'a>, fields: &'a [&'a [Layout<'a>]], - value: BasicValueEnum<'ctx>, -) { +) -> FunctionValue<'ctx> { let layout = Layout::Union(UnionLayout::NonRecursive(fields)); let block = env.builder.get_insert_block().expect("to be in a function"); let di_location = env.builder.get_current_debug_location().unwrap(); - let (call_name, fn_name) = function_name_from_mode( + let (_, fn_name) = function_name_from_mode( layout_ids, &env.interns, "increment_union", @@ -1418,7 +1552,7 @@ fn modify_refcount_union<'a, 'ctx, 'env>( env.builder .set_current_debug_location(env.context, di_location); - call_help(env, function, mode, value, call_name); + function } fn modify_refcount_union_help<'a, 'ctx, 'env>( @@ -1509,7 +1643,7 @@ fn modify_refcount_union_help<'a, 'ctx, 'env>( env, parent, layout_ids, - mode, + mode.to_call_mode(fn_val), when_recursive, field_ptr, field_layout, diff --git a/compiler/gen_dev/src/generic64/aarch64.rs b/compiler/gen_dev/src/generic64/aarch64.rs index 251d3aba44..dcfc32474a 100644 --- a/compiler/gen_dev/src/generic64/aarch64.rs +++ b/compiler/gen_dev/src/generic64/aarch64.rs @@ -250,6 +250,16 @@ impl Assembler for AArch64Assembler { unimplemented!("abs_reg64_reg64 is not yet implement for AArch64"); } + #[inline(always)] + fn abs_freg64_freg64( + _buf: &mut Vec<'_, u8>, + _relocs: &mut Vec<'_, Relocation>, + _dst: AArch64FloatReg, + _src: AArch64FloatReg, + ) { + unimplemented!("abs_reg64_reg64 is not yet implement for AArch64"); + } + #[inline(always)] fn add_reg64_reg64_imm32( buf: &mut Vec<'_, u8>, @@ -291,6 +301,16 @@ impl Assembler for AArch64Assembler { unimplemented!("calling functions literal not yet implemented for AArch64"); } + #[inline(always)] + fn imul_reg64_reg64_reg64( + _buf: &mut Vec<'_, u8>, + _dst: AArch64GeneralReg, + _src1: AArch64GeneralReg, + _src2: AArch64GeneralReg, + ) { + unimplemented!("register multiplication not implemented yet for AArch64"); + } + #[inline(always)] fn jmp_imm32(_buf: &mut Vec<'_, u8>, _offset: i32) -> usize { unimplemented!("jump instructions not yet implemented for AArch64"); diff --git a/compiler/gen_dev/src/generic64/mod.rs b/compiler/gen_dev/src/generic64/mod.rs index 672c8678b6..d075c39264 100644 --- a/compiler/gen_dev/src/generic64/mod.rs +++ b/compiler/gen_dev/src/generic64/mod.rs @@ -71,6 +71,12 @@ pub trait CallConv { /// dst should always come before sources. pub trait Assembler { fn abs_reg64_reg64(buf: &mut Vec<'_, u8>, dst: GeneralReg, src: GeneralReg); + fn abs_freg64_freg64( + buf: &mut Vec<'_, u8>, + relocs: &mut Vec<'_, Relocation>, + dst: FloatReg, + src: FloatReg, + ); fn add_reg64_reg64_imm32(buf: &mut Vec<'_, u8>, dst: GeneralReg, src1: GeneralReg, imm32: i32); fn add_freg64_freg64_freg64( @@ -124,6 +130,13 @@ pub trait Assembler { fn mov_stack32_freg64(buf: &mut Vec<'_, u8>, offset: i32, src: FloatReg); fn mov_stack32_reg64(buf: &mut Vec<'_, u8>, offset: i32, src: GeneralReg); + fn imul_reg64_reg64_reg64( + buf: &mut Vec<'_, u8>, + dst: GeneralReg, + src1: GeneralReg, + src2: GeneralReg, + ); + fn sub_reg64_reg64_imm32(buf: &mut Vec<'_, u8>, dst: GeneralReg, src1: GeneralReg, imm32: i32); fn sub_reg64_reg64_reg64( buf: &mut Vec<'_, u8>, @@ -468,6 +481,15 @@ impl< Ok(()) } + fn build_num_abs_f64(&mut self, dst: &Symbol, src: &Symbol) -> Result<(), String> { + let dst_reg = self.claim_float_reg(dst)?; + let src_reg = self.load_to_float_reg(src)?; + + ASM::abs_freg64_freg64(&mut self.buf, &mut self.relocs, dst_reg, src_reg); + + Ok(()) + } + fn build_num_add_i64( &mut self, dst: &Symbol, @@ -494,6 +516,19 @@ impl< Ok(()) } + fn build_num_mul_i64( + &mut self, + dst: &Symbol, + src1: &Symbol, + src2: &Symbol, + ) -> Result<(), String> { + let dst_reg = self.claim_general_reg(dst)?; + let src1_reg = self.load_to_general_reg(src1)?; + let src2_reg = self.load_to_general_reg(src2)?; + ASM::imul_reg64_reg64_reg64(&mut self.buf, dst_reg, src1_reg, src2_reg); + Ok(()) + } + fn build_num_sub_i64( &mut self, dst: &Symbol, diff --git a/compiler/gen_dev/src/generic64/x86_64.rs b/compiler/gen_dev/src/generic64/x86_64.rs index 7a26456f59..f54fb6bf47 100644 --- a/compiler/gen_dev/src/generic64/x86_64.rs +++ b/compiler/gen_dev/src/generic64/x86_64.rs @@ -740,6 +740,24 @@ impl Assembler for X86_64Assembler { cmovl_reg64_reg64(buf, dst, src); } + #[inline(always)] + fn abs_freg64_freg64( + buf: &mut Vec<'_, u8>, + relocs: &mut Vec<'_, Relocation>, + dst: X86_64FloatReg, + src: X86_64FloatReg, + ) { + movsd_freg64_rip_offset32(buf, dst, 0); + + // TODO: make sure this constant only loads once instead of every call to abs + relocs.push(Relocation::LocalData { + offset: buf.len() as u64 - 4, + data: 0x7fffffffffffffffu64.to_le_bytes().to_vec(), + }); + + andpd_freg64_freg64(buf, dst, src); + } + #[inline(always)] fn add_reg64_reg64_imm32( buf: &mut Vec<'_, u8>, @@ -796,6 +814,21 @@ impl Assembler for X86_64Assembler { }); } + #[inline(always)] + fn imul_reg64_reg64_reg64( + buf: &mut Vec<'_, u8>, + dst: X86_64GeneralReg, + src1: X86_64GeneralReg, + src2: X86_64GeneralReg, + ) { + if dst == src1 { + imul_reg64_reg64(buf, dst, src2); + } else { + mov_reg64_reg64(buf, dst, src1); + imul_reg64_reg64(buf, dst, src2); + } + } + #[inline(always)] fn jmp_imm32(buf: &mut Vec<'_, u8>, offset: i32) -> usize { jmp_imm32(buf, offset); @@ -976,6 +1009,21 @@ fn binop_reg64_reg64( buf.extend(&[rex, op_code, 0xC0 + dst_mod + src_mod]); } +#[inline(always)] +fn extended_binop_reg64_reg64( + op_code1: u8, + op_code2: u8, + buf: &mut Vec<'_, u8>, + dst: X86_64GeneralReg, + src: X86_64GeneralReg, +) { + let rex = add_rm_extension(dst, REX_W); + let rex = add_reg_extension(src, rex); + let dst_mod = dst as u8 % 8; + let src_mod = (src as u8 % 8) << 3; + buf.extend(&[rex, op_code1, op_code2, 0xC0 + dst_mod + src_mod]); +} + // Below here are the functions for all of the assembly instructions. // Their names are based on the instruction and operators combined. // You should call `buf.reserve()` if you push or extend more than once. @@ -1018,6 +1066,26 @@ fn addsd_freg64_freg64(buf: &mut Vec<'_, u8>, dst: X86_64FloatReg, src: X86_64Fl } } +#[inline(always)] +fn andpd_freg64_freg64(buf: &mut Vec<'_, u8>, dst: X86_64FloatReg, src: X86_64FloatReg) { + let dst_high = dst as u8 > 7; + let dst_mod = dst as u8 % 8; + let src_high = src as u8 > 7; + let src_mod = src as u8 % 8; + + if dst_high || src_high { + buf.extend(&[ + 0x66, + 0x40 + ((dst_high as u8) << 2) + (src_high as u8), + 0x0F, + 0x54, + 0xC0 + (dst_mod << 3) + (src_mod), + ]) + } else { + buf.extend(&[0x66, 0x0F, 0x54, 0xC0 + (dst_mod << 3) + (src_mod)]) + } +} + /// r/m64 AND imm8 (sign-extended). #[inline(always)] fn and_reg64_imm8(buf: &mut Vec<'_, u8>, dst: X86_64GeneralReg, imm: i8) { @@ -1052,6 +1120,14 @@ fn cmp_reg64_reg64(buf: &mut Vec<'_, u8>, dst: X86_64GeneralReg, src: X86_64Gene binop_reg64_reg64(0x39, buf, dst, src); } +/// `IMUL r64,r/m64` -> Signed Multiply r/m64 to r64. +#[inline(always)] +fn imul_reg64_reg64(buf: &mut Vec<'_, u8>, dst: X86_64GeneralReg, src: X86_64GeneralReg) { + // IMUL is strange, the parameters are reversed from must other binary ops. + // The final encoding is (src, dst) instead of (dst, src). + extended_binop_reg64_reg64(0x0F, 0xAF, buf, src, dst); +} + /// Jump near, relative, RIP = RIP + 32-bit displacement sign extended to 64-bits. #[inline(always)] fn jmp_imm32(buf: &mut Vec<'_, u8>, imm: i32) { @@ -1389,6 +1465,35 @@ mod tests { } } + #[test] + fn test_andpd_freg64_freg64() { + let arena = bumpalo::Bump::new(); + let mut buf = bumpalo::vec![in &arena]; + + for ((dst, src), expected) in &[ + ( + (X86_64FloatReg::XMM0, X86_64FloatReg::XMM0), + vec![0x66, 0x0F, 0x54, 0xC0], + ), + ( + (X86_64FloatReg::XMM0, X86_64FloatReg::XMM15), + vec![0x66, 0x41, 0x0F, 0x54, 0xC7], + ), + ( + (X86_64FloatReg::XMM15, X86_64FloatReg::XMM0), + vec![0x66, 0x44, 0x0F, 0x54, 0xF8], + ), + ( + (X86_64FloatReg::XMM15, X86_64FloatReg::XMM15), + vec![0x66, 0x45, 0x0F, 0x54, 0xFF], + ), + ] { + buf.clear(); + andpd_freg64_freg64(&mut buf, *dst, *src); + assert_eq!(&expected[..], &buf[..]); + } + } + #[test] fn test_xor_reg64_reg64() { let arena = bumpalo::Bump::new(); @@ -1460,6 +1565,34 @@ mod tests { } } + #[test] + fn test_imul_reg64_reg64() { + let arena = bumpalo::Bump::new(); + let mut buf = bumpalo::vec![in &arena]; + for ((dst, src), expected) in &[ + ( + (X86_64GeneralReg::RAX, X86_64GeneralReg::RAX), + [0x48, 0x0F, 0xAF, 0xC0], + ), + ( + (X86_64GeneralReg::RAX, X86_64GeneralReg::R15), + [0x49, 0x0F, 0xAF, 0xC7], + ), + ( + (X86_64GeneralReg::R15, X86_64GeneralReg::RAX), + [0x4C, 0x0F, 0xAF, 0xF8], + ), + ( + (X86_64GeneralReg::R15, X86_64GeneralReg::R15), + [0x4D, 0x0F, 0xAF, 0xFF], + ), + ] { + buf.clear(); + imul_reg64_reg64(&mut buf, *dst, *src); + assert_eq!(expected, &buf[..]); + } + } + #[test] fn test_jmp_imm32() { let arena = bumpalo::Bump::new(); diff --git a/compiler/gen_dev/src/lib.rs b/compiler/gen_dev/src/lib.rs index 49126a53a3..8280efe65e 100644 --- a/compiler/gen_dev/src/lib.rs +++ b/compiler/gen_dev/src/lib.rs @@ -184,6 +184,9 @@ where Symbol::NUM_ATAN => { self.build_run_low_level(sym, &LowLevel::NumAtan, arguments, layout) } + Symbol::NUM_MUL => { + self.build_run_low_level(sym, &LowLevel::NumMul, arguments, layout) + } Symbol::NUM_POW_INT => self.build_run_low_level( sym, &LowLevel::NumPowInt, @@ -237,6 +240,7 @@ where // TODO: when this is expanded to floats. deal with typecasting here, and then call correct low level method. match layout { Layout::Builtin(Builtin::Int64) => self.build_num_abs_i64(sym, &args[0]), + Layout::Builtin(Builtin::Float64) => self.build_num_abs_f64(sym, &args[0]), x => Err(format!("layout, {:?}, not implemented yet", x)), } } @@ -261,6 +265,15 @@ where LowLevel::NumAtan => { self.build_fn_call(sym, bitcode::NUM_ATAN.to_string(), args, &[*layout], layout) } + LowLevel::NumMul => { + // TODO: when this is expanded to floats. deal with typecasting here, and then call correct low level method. + match layout { + Layout::Builtin(Builtin::Int64) => { + self.build_num_mul_i64(sym, &args[0], &args[1]) + } + x => Err(format!("layout, {:?}, not implemented yet", x)), + } + } LowLevel::NumPowInt => self.build_fn_call( sym, bitcode::NUM_POW_INT.to_string(), @@ -302,6 +315,10 @@ where /// It only deals with inputs and outputs of i64 type. fn build_num_abs_i64(&mut self, dst: &Symbol, src: &Symbol) -> Result<(), String>; + /// build_num_abs_f64 stores the absolute value of src into dst. + /// It only deals with inputs and outputs of f64 type. + fn build_num_abs_f64(&mut self, dst: &Symbol, src: &Symbol) -> Result<(), String>; + /// build_num_add_i64 stores the sum of src1 and src2 into dst. /// It only deals with inputs and outputs of i64 type. fn build_num_add_i64( @@ -320,6 +337,15 @@ where src2: &Symbol, ) -> Result<(), String>; + /// build_num_mul_i64 stores `src1 * src2` into dst. + /// It only deals with inputs and outputs of i64 type. + fn build_num_mul_i64( + &mut self, + dst: &Symbol, + src1: &Symbol, + src2: &Symbol, + ) -> Result<(), String>; + /// build_num_sub_i64 stores the `src1 - src2` difference into dst. /// It only deals with inputs and outputs of i64 type. fn build_num_sub_i64( diff --git a/compiler/gen_dev/tests/gen_num.rs b/compiler/gen_dev/tests/gen_num.rs index 1c3f6d42b9..80324a3f69 100644 --- a/compiler/gen_dev/tests/gen_num.rs +++ b/compiler/gen_dev/tests/gen_num.rs @@ -75,6 +75,19 @@ mod gen_num { ); } + #[test] + fn gen_mul_i64() { + assert_evals_to!( + indoc!( + r#" + 2 * 4 * 6 + "# + ), + 48, + i64 + ); + } + #[test] fn i64_force_stack() { // This claims 33 registers. One more than Arm and RISC-V, and many more than x86-64. @@ -247,24 +260,6 @@ mod gen_num { -1, i64 ); - - assert_evals_to!( - indoc!( - r#" - limitedNegate = \num -> - if num == 1 then - -1 - else if num == -1 then - 1 - else - num - - limitedNegate 1 - "# - ), - -1, - i64 - ); } #[test] @@ -288,6 +283,12 @@ mod gen_num { ); } + #[test] + fn f64_abs() { + assert_evals_to!("Num.abs -4.7", 4.7, f64); + assert_evals_to!("Num.abs 5.8", 5.8, f64); + } + /* #[test] fn f64_sqrt() { @@ -310,11 +311,7 @@ mod gen_num { assert_evals_to!("Num.round 3.6", 4, i64); } - #[test] - fn f64_abs() { - assert_evals_to!("Num.abs -4.7", 4.7, f64); - assert_evals_to!("Num.abs 5.8", 5.8, f64); - } + #[test] fn gen_float_eq() { @@ -388,32 +385,6 @@ mod gen_num { ); } - #[test] - fn gen_sub_i64() { - assert_evals_to!( - indoc!( - r#" - 1 - 2 - 3 - "# - ), - -4, - i64 - ); - } - - #[test] - fn gen_mul_i64() { - assert_evals_to!( - indoc!( - r#" - 2 * 4 * 6 - "# - ), - 48, - i64 - ); - } - #[test] fn gen_div_i64() { assert_evals_to!( diff --git a/editor/editor-ideas.md b/editor/editor-ideas.md index 92a7a796e9..41ff4dbfff 100644 --- a/editor/editor-ideas.md +++ b/editor/editor-ideas.md @@ -37,10 +37,11 @@ Nice collection of research on innovative editors, [link](https://futureofcoding * [VS code debug visualization](https://marketplace.visualstudio.com/items?itemName=hediet.debug-visualizer) * [Algorithm visualization for javascript](https://algorithm-visualizer.org) * [godbolt.org Compiler Explorer](https://godbolt.org/) -* Say you have a failing test that used to work, it would be very valuable to see all code that was changed that was used only by that test. -e.g. you have a test `calculate_sum_test` that only uses the function `add`, when the test fails you should be able to see a diff showing only what changed for the function `add`. It would also be great to have a diff of [expression values](https://homepages.cwi.nl/~storm/livelit/images/bret.png) Bret Victor style. An ambitious project would be to suggest or automatically try fixes based on these diffs. * [whitebox debug visualization](https://vimeo.com/483795097) * [Hest](https://ivanish.ca/hest-time-travel/) tool for making highly interactive simulations. +* Say you have a failing test that used to work, it would be very valuable to see all code that was changed that was used only by that test. +e.g. you have a test `calculate_sum_test` that only uses the function `add`, when the test fails you should be able to see a diff showing only what changed for the function `add`. It would also be great to have a diff of [expression values](https://homepages.cwi.nl/~storm/livelit/images/bret.png) Bret Victor style. An ambitious project would be to suggest or automatically try fixes based on these diffs. +* I think it could be possible to create a minimal reproduction of a program / block of code / code used by a single test. So for a failing unit test I would expect it to extract imports, the platform, types and functions that are necessary to run only that unit test and put them in a standalone roc project. This would be useful for sharing bugs with library+application authors and colleagues, for profiling or debugging with all "clutter" removed. ### Structured Editing @@ -86,6 +87,12 @@ e.g. you have a test `calculate_sum_test` that only uses the function `add`, whe * Mozilla DeepSpeech model runs fast, works pretty well for actions but would need additional training for code input. Possible to reuse [Mozilla common voice](https://github.com/common-voice/common-voice) for creating more "spoken code" data. +### Beginner-focused Features + + * Show Roc cheat sheet on start-up. + * Plugin that translates short pieces of code from another programming language to Roc. [Relevant research](https://www.youtube.com/watch?v=xTzFJIknh7E). Someone who only knows the R language could get started with Roc with less friction if they could quickly define a list R style (`lst <- c(1,2,3)`) and get it translated to Roc. + * Being able to asses or ask the user for the amount of experience they have with Roc would be a valuable feature for recommending plugins, editor tips, recommending tutorials, automated error search (e.g searching common beginner errors first), ... . + ### Productivity features * When refactoring; @@ -106,6 +113,13 @@ e.g. you have a test `calculate_sum_test` that only uses the function `add`, whe * Regex-like find and substitution based on plain english description and example (replacement). i.e. replace all `[` between double quotes with `{`. [Inspiration](https://alexmoltzau.medium.com/english-to-regex-thanks-to-gpt-3-13f03b68236e). * Show productivity tips based on behavior. i.e. if the user is scrolling through the error bar and clicking on the next error several times, show a tip with "go to next error" shortcut. * Command to "benchmark this function" or "benchmark this test" with flamegraph and execution time per line. +* Instead of going to definition and having to navigate back and forth between files, show an editable view inside the current file. See [this video](https://www.youtube.com/watch?v=EenznqbW5w8) +* When encountering an unexpected error in the user's program we show a button at the bottom to start an automated search on this error. The search would: + * look for similar errors in github issues of the relevant libraries + * search stackoverflow questions + * search a local history of previously encountered errors and fixes + * search through a database of our zullip questions + * ... #### Autocomplete @@ -124,6 +138,8 @@ e.g. you have a test `calculate_sum_test` that only uses the function `add`, whe * [Codota](https://www.codota.com) AI autocomplete and example searching. * [Aroma](https://ai.facebook.com/blog/aroma-ml-for-code-recommendation) showing examples similar to current code. * [MISM](https://arxiv.org/abs/2006.05265) neural network based code similarity scoring. +* [Inquisitive code editor](https://web.eecs.utk.edu/~azh/blog/inquisitivecodeeditor.html) Interactive bug detection with doc+test generation. +* [NextJournal](https://nextjournal.com/joe-loco/command-bar?token=DpU6ewNQnLhYtVkwhs9GeX) Discoverable commands and shortcuts. ### Non-Code Related Inspiration diff --git a/editor/src/lang/constrain.rs b/editor/src/lang/constrain.rs index f9ae771a86..6a940c50f4 100644 --- a/editor/src/lang/constrain.rs +++ b/editor/src/lang/constrain.rs @@ -9,7 +9,7 @@ use crate::lang::{ use roc_can::expected::Expected; use roc_collections::all::{BumpMap, BumpMapDefault, Index}; -use roc_module::symbol::Symbol; +use roc_module::{ident::TagName, symbol::Symbol}; use roc_region::all::{Located, Region}; use roc_types::{ subs::Variable, @@ -21,7 +21,7 @@ use roc_types::{ pub enum Constraint<'a> { Eq(Type2, Expected, Category, Region), // Store(Type, Variable, &'static str, u32), - // Lookup(Symbol, Expected, Region), + Lookup(Symbol, Expected, Region), // Pattern(Region, PatternCategory, Type, PExpected), And(BumpVec<'a, Constraint<'a>>), Let(&'a LetConstraint<'a>), @@ -52,6 +52,7 @@ pub fn constrain_expr<'a>( Expr2::SmallStr(_) => Eq(str_type(env.pool), expected, Category::Str, region), Expr2::Blank => True, Expr2::EmptyRecord => constrain_empty_record(expected, region), + Expr2::Var(symbol) => Lookup(*symbol, expected, region), Expr2::SmallInt { var, .. } => { let mut flex_vars = BumpVec::with_capacity_in(1, arena); @@ -216,6 +217,220 @@ pub fn constrain_expr<'a>( exists(arena, field_vars, And(constraints)) } } + Expr2::GlobalTag { + variant_var, + ext_var, + name, + arguments, + } => { + let mut flex_vars = BumpVec::with_capacity_in(arguments.len(), arena); + let types = PoolVec::with_capacity(arguments.len() as u32, env.pool); + let mut arg_cons = BumpVec::with_capacity_in(arguments.len(), arena); + + for (argument_node_id, type_node_id) in + arguments.iter_node_ids().zip(types.iter_node_ids()) + { + let (var, expr_node_id) = env.pool.get(argument_node_id); + + let argument_expr = env.pool.get(*expr_node_id); + + let arg_con = constrain_expr( + arena, + env, + argument_expr, + Expected::NoExpectation(Type2::Variable(*var)), + region, + ); + + arg_cons.push(arg_con); + flex_vars.push(*var); + + env.pool[type_node_id] = Type2::Variable(*var); + } + + let union_con = Eq( + Type2::TagUnion( + PoolVec::new(std::iter::once((*name, types)), env.pool), + env.pool.add(Type2::Variable(*ext_var)), + ), + expected.shallow_clone(), + Category::TagApply { + tag_name: TagName::Global(name.as_str(env.pool).into()), + args_count: arguments.len(), + }, + region, + ); + + let ast_con = Eq( + Type2::Variable(*variant_var), + expected, + Category::Storage(std::file!(), std::line!()), + region, + ); + + flex_vars.push(*variant_var); + flex_vars.push(*ext_var); + + arg_cons.push(union_con); + arg_cons.push(ast_con); + + exists(arena, flex_vars, And(arg_cons)) + } + Expr2::Call { + args, + expr_var, + expr: expr_node_id, + closure_var, + fn_var, + .. + } => { + // The expression that evaluates to the function being called, e.g. `foo` in + // (foo) bar baz + let expr = env.pool.get(*expr_node_id); + + let opt_symbol = if let Expr2::Var(symbol) = expr { + Some(*symbol) + } else { + None + }; + + let fn_type = Type2::Variable(*fn_var); + let fn_region = region; + let fn_expected = Expected::NoExpectation(fn_type.shallow_clone()); + + let fn_reason = Reason::FnCall { + name: opt_symbol, + arity: args.len() as u8, + }; + + let fn_con = constrain_expr(arena, env, expr, fn_expected, region); + + // The function's return type + // TODO: don't use expr_var? + let ret_type = Type2::Variable(*expr_var); + + // type of values captured in the closure + let closure_type = Type2::Variable(*closure_var); + + // This will be used in the occurs check + let mut vars = BumpVec::with_capacity_in(2 + args.len(), arena); + + vars.push(*fn_var); + // TODO: don't use expr_var? + vars.push(*expr_var); + vars.push(*closure_var); + + let mut arg_types = BumpVec::with_capacity_in(args.len(), arena); + let mut arg_cons = BumpVec::with_capacity_in(args.len(), arena); + + for (index, arg_node_id) in args.iter_node_ids().enumerate() { + let (arg_var, arg) = env.pool.get(arg_node_id); + let arg_expr = env.pool.get(*arg); + + let region = region; + let arg_type = Type2::Variable(*arg_var); + + let reason = Reason::FnArg { + name: opt_symbol, + arg_index: Index::zero_based(index), + }; + + let expected_arg = Expected::ForReason(reason, arg_type.shallow_clone(), region); + + let arg_con = constrain_expr(arena, env, arg_expr, expected_arg, region); + + vars.push(*arg_var); + arg_types.push(arg_type); + arg_cons.push(arg_con); + } + + let expected_fn_type = Expected::ForReason( + fn_reason, + Type2::Function( + PoolVec::new(arg_types.into_iter(), env.pool), + env.pool.add(closure_type), + env.pool.add(ret_type.shallow_clone()), + ), + region, + ); + + let category = Category::CallResult(opt_symbol); + + let mut and_constraints = BumpVec::with_capacity_in(4, arena); + + and_constraints.push(fn_con); + and_constraints.push(Eq(fn_type, expected_fn_type, category.clone(), fn_region)); + and_constraints.push(And(arg_cons)); + and_constraints.push(Eq(ret_type, expected, category, region)); + + exists(arena, vars, And(and_constraints)) + } + Expr2::Accessor { + function_var, + closure_var, + field, + record_var, + ext_var, + field_var, + } => { + let ext_var = *ext_var; + let ext_type = Type2::Variable(ext_var); + + let field_var = *field_var; + let field_type = Type2::Variable(field_var); + + let record_field = + types::RecordField::Demanded(env.pool.add(field_type.shallow_clone())); + + let record_type = Type2::Record( + PoolVec::new(vec![(*field, record_field)].into_iter(), env.pool), + env.pool.add(ext_type), + ); + + let category = Category::Accessor(field.as_str(env.pool).into()); + + let record_expected = Expected::NoExpectation(record_type.shallow_clone()); + let record_con = Eq( + Type2::Variable(*record_var), + record_expected, + category.clone(), + region, + ); + + let function_type = Type2::Function( + PoolVec::new(vec![record_type].into_iter(), env.pool), + env.pool.add(Type2::Variable(*closure_var)), + env.pool.add(field_type), + ); + + let mut flex_vars = BumpVec::with_capacity_in(5, arena); + + flex_vars.push(*record_var); + flex_vars.push(*function_var); + flex_vars.push(*closure_var); + flex_vars.push(field_var); + flex_vars.push(ext_var); + + let mut and_constraints = BumpVec::with_capacity_in(3, arena); + + and_constraints.push(Eq( + function_type.shallow_clone(), + expected, + category.clone(), + region, + )); + + and_constraints.push(Eq( + function_type, + Expected::NoExpectation(Type2::Variable(*function_var)), + category, + region, + )); + + and_constraints.push(record_con); + + exists(arena, flex_vars, And(and_constraints)) + } _ => todo!("implement constaints for {:?}", expr), } } @@ -268,13 +483,7 @@ fn empty_list_type(pool: &mut Pool, var: Variable) -> Type2 { #[inline(always)] fn list_type(pool: &mut Pool, typ: Type2) -> Type2 { - let args = PoolVec::with_capacity(1, pool); - - for (arg_node_id, arg) in args.iter_node_ids().zip(vec![typ]) { - pool[arg_node_id] = arg; - } - - builtin_type(Symbol::LIST_LIST, args) + builtin_type(Symbol::LIST_LIST, PoolVec::new(vec![typ].into_iter(), pool)) } #[inline(always)] diff --git a/editor/src/lang/solve.rs b/editor/src/lang/solve.rs index 3d7118e4f4..7f9d5c7874 100644 --- a/editor/src/lang/solve.rs +++ b/editor/src/lang/solve.rs @@ -1,7 +1,7 @@ #![allow(clippy::all)] #![allow(dead_code)] use crate::lang::constrain::Constraint::{self, *}; -use crate::lang::pool::Pool; +use crate::lang::pool::{Pool, ShallowClone}; use crate::lang::types::Type2; use bumpalo::Bump; use roc_can::expected::{Expected, PExpected}; @@ -270,75 +270,79 @@ fn solve<'a>( // } // } // } - // Lookup(symbol, expectation, region) => { - // match env.vars_by_symbol.get(&symbol) { - // Some(var) => { - // // Deep copy the vars associated with this symbol before unifying them. - // // Otherwise, suppose we have this: - // // - // // identity = \a -> a - // // - // // x = identity 5 - // // - // // When we call (identity 5), it's important that we not unify - // // on identity's original vars. If we do, the type of `identity` will be - // // mutated to be `Int -> Int` instead of `a -> `, which would be incorrect; - // // the type of `identity` is more general than that! - // // - // // Instead, we want to unify on a *copy* of its vars. If the copy unifies - // // successfully (in this case, to `Int -> Int`), we can use that to - // // infer the type of this lookup (in this case, `Int`) without ever - // // having mutated the original. - // // - // // If this Lookup is targeting a value in another module, - // // then we copy from that module's Subs into our own. If the value - // // is being looked up in this module, then we use our Subs as both - // // the source and destination. - // let actual = deep_copy_var(subs, rank, pools, *var); - // let expected = type_to_var( - // subs, - // rank, - // pools, - // cached_aliases, - // expectation.get_type_ref(), - // ); - // match unify(subs, actual, expected) { - // Success(vars) => { - // introduce(subs, rank, pools, &vars); - // - // state - // } - // - // Failure(vars, actual_type, expected_type) => { - // introduce(subs, rank, pools, &vars); - // - // let problem = TypeError::BadExpr( - // *region, - // Category::Lookup(*symbol), - // actual_type, - // expectation.clone().replace(expected_type), - // ); - // - // problems.push(problem); - // - // state - // } - // BadType(vars, problem) => { - // introduce(subs, rank, pools, &vars); - // - // problems.push(TypeError::BadType(problem)); - // - // state - // } - // } - // } - // None => { - // problems.push(TypeError::UnexposedLookup(*symbol)); - // - // state - // } - // } - // } + Lookup(symbol, expectation, region) => { + match env.vars_by_symbol.get(&symbol) { + Some(var) => { + // Deep copy the vars associated with this symbol before unifying them. + // Otherwise, suppose we have this: + // + // identity = \a -> a + // + // x = identity 5 + // + // When we call (identity 5), it's important that we not unify + // on identity's original vars. If we do, the type of `identity` will be + // mutated to be `Int -> Int` instead of `a -> `, which would be incorrect; + // the type of `identity` is more general than that! + // + // Instead, we want to unify on a *copy* of its vars. If the copy unifies + // successfully (in this case, to `Int -> Int`), we can use that to + // infer the type of this lookup (in this case, `Int`) without ever + // having mutated the original. + // + // If this Lookup is targeting a value in another module, + // then we copy from that module's Subs into our own. If the value + // is being looked up in this module, then we use our Subs as both + // the source and destination. + let actual = deep_copy_var(subs, rank, pools, *var); + + let expected = type_to_var( + arena, + mempool, + subs, + rank, + pools, + cached_aliases, + expectation.get_type_ref(), + ); + + match unify(subs, actual, expected) { + Success(vars) => { + introduce(subs, rank, pools, &vars); + + state + } + + Failure(vars, actual_type, expected_type) => { + introduce(subs, rank, pools, &vars); + + let problem = TypeError::BadExpr( + *region, + Category::Lookup(*symbol), + actual_type, + expectation.shallow_clone().replace(expected_type), + ); + + problems.push(problem); + + state + } + BadType(vars, problem) => { + introduce(subs, rank, pools, &vars); + + problems.push(TypeError::BadType(problem)); + + state + } + } + } + None => { + problems.push(TypeError::UnexposedLookup(*symbol)); + + state + } + } + } And(sub_constraints) => { let mut state = state; @@ -826,7 +830,7 @@ fn type_to_variable<'a>( let mut tag_vars = MutMap::default(); let ext = mempool.get(*ext_id); - for (_tag, tag_argument_types) in tags.iter(mempool) { + for (tag, tag_argument_types) in tags.iter(mempool) { let mut tag_argument_vars = Vec::with_capacity(tag_argument_types.len()); for arg_type in tag_argument_types.iter(mempool) { @@ -836,7 +840,7 @@ fn type_to_variable<'a>( } tag_vars.insert( - roc_module::ident::TagName::Private(Symbol::NUM_NUM), + roc_module::ident::TagName::Global(tag.as_str(mempool).into()), tag_argument_vars, ); } @@ -857,22 +861,30 @@ fn type_to_variable<'a>( register(subs, rank, pools, content) } + // This case is important for the rank of boolean variables + Function(args, closure_type_id, ret_type_id) => { + let mut arg_vars = Vec::with_capacity(args.len()); + + let closure_type = mempool.get(*closure_type_id); + let ret_type = mempool.get(*ret_type_id); + + for arg_id in args.iter_node_ids() { + let arg = mempool.get(arg_id); + + arg_vars.push(type_to_variable( + arena, mempool, subs, rank, pools, cached, arg, + )) + } + + let ret_var = type_to_variable(arena, mempool, subs, rank, pools, cached, ret_type); + let closure_var = + type_to_variable(arena, mempool, subs, rank, pools, cached, closure_type); + + let content = Content::Structure(FlatType::Func(arg_vars, closure_var, ret_var)); + + register(subs, rank, pools, content) + } other => todo!("not implemented {:?}", &other), - // - // // This case is important for the rank of boolean variables - // Function(args, closure_type, ret_type) => { - // let mut arg_vars = Vec::with_capacity(args.len()); - // - // for arg in args { - // arg_vars.push(type_to_variable(subs, rank, pools, cached, arg)) - // } - // - // let ret_var = type_to_variable(subs, rank, pools, cached, ret_type); - // let closure_var = type_to_variable(subs, rank, pools, cached, closure_type); - // let content = Content::Structure(FlatType::Func(arg_vars, closure_var, ret_var)); - // - // register(subs, rank, pools, content) - // } // RecursiveTagUnion(rec_var, tags, ext) => { // let mut tag_vars = MutMap::default(); // diff --git a/editor/src/lang/types.rs b/editor/src/lang/types.rs index 5ae8a0e054..d21e74ff50 100644 --- a/editor/src/lang/types.rs +++ b/editor/src/lang/types.rs @@ -62,9 +62,15 @@ impl ShallowClone for Type2 { fn shallow_clone(&self) -> Self { match self { Self::Variable(var) => Self::Variable(*var), - Self::Alias(symbol, pool_vec, type_id) => { - Self::Alias(*symbol, pool_vec.shallow_clone(), type_id.clone()) + Self::Alias(symbol, args, alias_type_id) => { + Self::Alias(*symbol, args.shallow_clone(), alias_type_id.clone()) } + Self::Record(fields, ext_id) => Self::Record(fields.shallow_clone(), ext_id.clone()), + Self::Function(args, closure_type_id, ret_type_id) => Self::Function( + args.shallow_clone(), + closure_type_id.clone(), + ret_type_id.clone(), + ), rest => todo!("{:?}", rest), } } diff --git a/editor/tests/solve_expr2.rs b/editor/tests/solve_expr2.rs index 8faa4bcacb..2c3a7d1299 100644 --- a/editor/tests/solve_expr2.rs +++ b/editor/tests/solve_expr2.rs @@ -238,3 +238,27 @@ fn constrain_list_of_records() { "List { x : Num * }", ) } + +#[test] +fn constrain_global_tag() { + infer_eq( + indoc!( + r#" + Foo + "# + ), + "[ Foo ]*", + ) +} + +#[test] +fn constrain_call_and_accessor() { + infer_eq( + indoc!( + r#" + .foo { foo: "bar" } + "# + ), + "Str", + ) +} diff --git a/roc-for-elm-programmers.md b/roc-for-elm-programmers.md index d9850d63e4..bad83663cd 100644 --- a/roc-for-elm-programmers.md +++ b/roc-for-elm-programmers.md @@ -209,7 +209,7 @@ In Elm: In Roc: ``` -{ x : name : Str, email : Str }* -> Str +{ name : Str, email : Str }* -> Str ``` Here, the open record's type variable appears immediately after the `}`.