diff --git a/cli/tests/cli_run.rs b/cli/tests/cli_run.rs index 5b8ec1acce..63996d3c92 100644 --- a/cli/tests/cli_run.rs +++ b/cli/tests/cli_run.rs @@ -227,6 +227,22 @@ mod cli_run { ); } + + #[test] + #[ignore] + #[serial(astar)] + fn run_astar_optimized_1() { + check_output_with_stdin( + &example_file("benchmarks", "AStarTests.roc"), + "1", + "astar-tests", + &[], + "True\n", + false, + ); + } + + #[ignore] #[test] #[serial(closure1)] fn closure1() { diff --git a/compiler/build/src/program.rs b/compiler/build/src/program.rs index e97c015d09..ee0d990ce5 100644 --- a/compiler/build/src/program.rs +++ b/compiler/build/src/program.rs @@ -97,6 +97,10 @@ pub fn gen_from_mono_module( if name.starts_with("roc_builtins.dict") || name.starts_with("dict.RocDict") { function.add_attribute(AttributeLoc::Function, attr); } + + if name.starts_with("roc_builtins.list") || name.starts_with("list.RocList") { + function.add_attribute(AttributeLoc::Function, attr); + } } let builder = context.create_builder(); diff --git a/compiler/builtins/bitcode/src/dict.zig b/compiler/builtins/bitcode/src/dict.zig index 4580e04176..f6ae1913db 100644 --- a/compiler/builtins/bitcode/src/dict.zig +++ b/compiler/builtins/bitcode/src/dict.zig @@ -5,9 +5,10 @@ const mem = std.mem; const Allocator = mem.Allocator; const assert = std.debug.assert; +const utils = @import("utils.zig"); +const RocList = @import("list.zig").RocList; + const INITIAL_SEED = 0xc70f6907; -const REFCOUNT_ONE_ISIZE: comptime isize = std.math.minInt(isize); -const REFCOUNT_ONE: usize = @bitCast(usize, REFCOUNT_ONE_ISIZE); const InPlace = packed enum(u8) { InPlace, @@ -92,6 +93,23 @@ const Alignment = packed enum(u8) { } }; +pub fn decref( + allocator: *Allocator, + alignment: Alignment, + bytes_or_null: ?[*]u8, + data_bytes: usize, +) void { + return utils.decref(allocator, alignment.toUsize(), bytes_or_null, data_bytes); +} + +pub fn allocateWithRefcount( + allocator: *Allocator, + alignment: Alignment, + data_bytes: usize, +) [*]u8 { + return utils.allocateWithRefcount(allocator, alignment.toUsize(), data_bytes); +} + pub const RocDict = extern struct { dict_bytes: ?[*]u8, dict_entries_len: usize, @@ -211,7 +229,7 @@ pub const RocDict = extern struct { // otherwise, check if the refcount is one const ptr: [*]usize = @ptrCast([*]usize, @alignCast(8, self.dict_bytes)); - return (ptr - 1)[0] == REFCOUNT_ONE; + return (ptr - 1)[0] == utils.REFCOUNT_ONE; } pub fn capacity(self: RocDict) usize { @@ -228,8 +246,6 @@ pub const RocDict = extern struct { } // unfortunately, we have to clone - - const in_place = InPlace.Clone; var new_dict = RocDict.allocate(allocator, self.number_of_levels, self.dict_entries_len, alignment, key_width, value_width); var old_bytes: [*]u8 = @ptrCast([*]u8, self.dict_bytes); @@ -260,6 +276,10 @@ pub const RocDict = extern struct { } fn setKey(self: *RocDict, index: usize, alignment: Alignment, key_width: usize, value_width: usize, data: Opaque) void { + if (key_width == 0) { + return; + } + const offset = blk: { if (alignment.keyFirst()) { break :blk (index * key_width); @@ -289,6 +309,10 @@ pub const RocDict = extern struct { } fn setValue(self: *RocDict, index: usize, alignment: Alignment, key_width: usize, value_width: usize, data: Opaque) void { + if (value_width == 0) { + return; + } + const offset = blk: { if (alignment.keyFirst()) { break :blk (self.capacity() * key_width) + (index * value_width); @@ -518,11 +542,6 @@ pub fn elementsRc(dict: RocDict, alignment: Alignment, key_width: usize, value_w } } -pub const RocList = extern struct { - bytes: ?[*]u8, - length: usize, -}; - pub fn dictKeys(dict: RocDict, alignment: Alignment, key_width: usize, value_width: usize, inc_key: Inc, output: *RocList) callconv(.C) void { const size = dict.capacity(); @@ -538,7 +557,7 @@ pub fn dictKeys(dict: RocDict, alignment: Alignment, key_width: usize, value_wid } if (length == 0) { - output.* = RocList{ .bytes = null, .length = 0 }; + output.* = RocList.empty(); return; } @@ -587,7 +606,7 @@ pub fn dictValues(dict: RocDict, alignment: Alignment, key_width: usize, value_w } if (length == 0) { - output.* = RocList{ .bytes = null, .length = 0 }; + output.* = RocList.empty(); return; } @@ -621,82 +640,131 @@ pub fn dictValues(dict: RocDict, alignment: Alignment, key_width: usize, value_w output.* = RocList{ .bytes = ptr, .length = length }; } -fn decref( - allocator: *Allocator, - alignment: Alignment, - bytes_or_null: ?[*]u8, - data_bytes: usize, -) void { - var bytes = bytes_or_null orelse return; +fn doNothing(ptr: Opaque) callconv(.C) void { + return; +} - const usizes: [*]usize = @ptrCast([*]usize, @alignCast(8, bytes)); +pub fn dictUnion(dict1: RocDict, dict2: RocDict, alignment: Alignment, key_width: usize, value_width: usize, hash_fn: HashFn, is_eq: EqFn, inc_key: Inc, inc_value: Inc, output: *RocDict) callconv(.C) void { + output.* = dict1.makeUnique(std.heap.c_allocator, alignment, key_width, value_width); - const refcount = (usizes - 1)[0]; - const refcount_isize = @bitCast(isize, refcount); + var i: usize = 0; + while (i < dict2.capacity()) : (i += 1) { + switch (dict2.getSlot(i, key_width, value_width)) { + Slot.Filled => { + const key = dict2.getKey(i, alignment, key_width, value_width); - switch (alignment.toUsize()) { - 8 => { - if (refcount == REFCOUNT_ONE) { - allocator.free((bytes - 8)[0 .. 8 + data_bytes]); - } else if (refcount_isize < 0) { - (usizes - 1)[0] = refcount + 1; - } - }, - 16 => { - if (refcount == REFCOUNT_ONE) { - allocator.free((bytes - 16)[0 .. 16 + data_bytes]); - } else if (refcount_isize < 0) { - (usizes - 1)[0] = refcount + 1; - } - }, - else => unreachable, + switch (output.findIndex(alignment, key, key_width, value_width, hash_fn, is_eq)) { + MaybeIndex.not_found => { + const value = dict2.getValue(i, alignment, key_width, value_width); + inc_value(value); + + // we need an extra RC token for the key + inc_key(key); + + const dec_key = doNothing; + const dec_value = doNothing; + + dictInsert(output.*, alignment, key, key_width, value, value_width, hash_fn, is_eq, dec_key, dec_value, output); + }, + MaybeIndex.index => |_| { + // the key is already in the output dict + continue; + }, + } + }, + else => {}, + } } } -fn allocateWithRefcount( - allocator: *Allocator, - alignment: Alignment, - data_bytes: usize, -) [*]u8 { - comptime const result_in_place = InPlace.Clone; +pub fn dictIntersection(dict1: RocDict, dict2: RocDict, alignment: Alignment, key_width: usize, value_width: usize, hash_fn: HashFn, is_eq: EqFn, dec_key: Inc, dec_value: Inc, output: *RocDict) callconv(.C) void { + output.* = dict1.makeUnique(std.heap.c_allocator, alignment, key_width, value_width); - switch (alignment.toUsize()) { - 8 => { - const length = @sizeOf(usize) + data_bytes; + var i: usize = 0; + const size = dict1.capacity(); + while (i < size) : (i += 1) { + switch (output.getSlot(i, key_width, value_width)) { + Slot.Filled => { + const key = dict1.getKey(i, alignment, key_width, value_width); - var new_bytes: []align(8) u8 = allocator.alignedAlloc(u8, 8, length) catch unreachable; - - var as_usize_array = @ptrCast([*]usize, new_bytes); - if (result_in_place == InPlace.InPlace) { - as_usize_array[0] = @intCast(usize, number_of_slots); - } else { - as_usize_array[0] = REFCOUNT_ONE; - } - - var as_u8_array = @ptrCast([*]u8, new_bytes); - const first_slot = as_u8_array + @sizeOf(usize); - - return first_slot; - }, - 16 => { - const length = 2 * @sizeOf(usize) + data_bytes; - - var new_bytes: []align(16) u8 = allocator.alignedAlloc(u8, 16, length) catch unreachable; - - var as_usize_array = @ptrCast([*]usize, new_bytes); - if (result_in_place == InPlace.InPlace) { - as_usize_array[0] = 0; - as_usize_array[1] = @intCast(usize, number_of_slots); - } else { - as_usize_array[0] = 0; - as_usize_array[1] = REFCOUNT_ONE; - } - - var as_u8_array = @ptrCast([*]u8, new_bytes); - const first_slot = as_u8_array + 2 * @sizeOf(usize); - - return first_slot; - }, - else => unreachable, + switch (dict2.findIndex(alignment, key, key_width, value_width, hash_fn, is_eq)) { + MaybeIndex.not_found => { + dictRemove(output.*, alignment, key, key_width, value_width, hash_fn, is_eq, dec_key, dec_value, output); + }, + MaybeIndex.index => |_| { + // keep this key/value + continue; + }, + } + }, + else => {}, + } } } + +pub fn dictDifference(dict1: RocDict, dict2: RocDict, alignment: Alignment, key_width: usize, value_width: usize, hash_fn: HashFn, is_eq: EqFn, dec_key: Inc, dec_value: Inc, output: *RocDict) callconv(.C) void { + output.* = dict1.makeUnique(std.heap.c_allocator, alignment, key_width, value_width); + + var i: usize = 0; + const size = dict1.capacity(); + while (i < size) : (i += 1) { + switch (output.getSlot(i, key_width, value_width)) { + Slot.Filled => { + const key = dict1.getKey(i, alignment, key_width, value_width); + + switch (dict2.findIndex(alignment, key, key_width, value_width, hash_fn, is_eq)) { + MaybeIndex.not_found => { + // keep this key/value + continue; + }, + MaybeIndex.index => |_| { + dictRemove(output.*, alignment, key, key_width, value_width, hash_fn, is_eq, dec_key, dec_value, output); + }, + } + }, + else => {}, + } + } +} + +pub fn setFromList(list: RocList, alignment: Alignment, key_width: usize, value_width: usize, hash_fn: HashFn, is_eq: EqFn, dec_key: Dec, output: *RocDict) callconv(.C) void { + output.* = RocDict.empty(); + + var ptr = @ptrCast([*]u8, list.bytes); + + const dec_value = doNothing; + const value = null; + + const size = list.length; + var i: usize = 0; + while (i < size) : (i += 1) { + const key = ptr + i * key_width; + dictInsert(output.*, alignment, key, key_width, value, value_width, hash_fn, is_eq, dec_key, dec_value, output); + } + + // NOTE: decref checks for the empty case + const data_bytes = size * key_width; + decref(std.heap.c_allocator, alignment, list.bytes, data_bytes); +} + +const StepperCaller = fn (?[*]u8, ?[*]u8, ?[*]u8, ?[*]u8, ?[*]u8) callconv(.C) void; +pub fn dictWalk(dict: RocDict, stepper: Opaque, stepper_caller: StepperCaller, accum: Opaque, alignment: Alignment, key_width: usize, value_width: usize, accum_width: usize, output: Opaque) callconv(.C) void { + @memcpy(output orelse unreachable, accum orelse unreachable, accum_width); + + var i: usize = 0; + const size = dict.capacity(); + while (i < size) : (i += 1) { + switch (dict.getSlot(i, key_width, value_width)) { + Slot.Filled => { + const key = dict.getKey(i, alignment, key_width, value_width); + const value = dict.getValue(i, alignment, key_width, value_width); + + stepper_caller(stepper, key, value, output, output); + }, + else => {}, + } + } + + const data_bytes = dict.capacity() * slotSize(key_width, value_width); + decref(std.heap.c_allocator, alignment, dict.dict_bytes, data_bytes); +} diff --git a/compiler/builtins/bitcode/src/list.zig b/compiler/builtins/bitcode/src/list.zig new file mode 100644 index 0000000000..8ea824fa31 --- /dev/null +++ b/compiler/builtins/bitcode/src/list.zig @@ -0,0 +1,280 @@ +const std = @import("std"); +const utils = @import("utils.zig"); +const RocResult = utils.RocResult; +const mem = std.mem; +const Allocator = mem.Allocator; + +const EqFn = fn (?[*]u8, ?[*]u8) callconv(.C) bool; +const Opaque = ?[*]u8; + +pub const RocList = extern struct { + bytes: ?[*]u8, + length: usize, + + pub fn len(self: RocList) usize { + return self.length; + } + + pub fn isEmpty(self: RocList) bool { + return self.len() == 0; + } + + pub fn empty() RocList { + return RocList{ .bytes = null, .length = 0 }; + } + + pub fn isUnique(self: RocList) bool { + // the empty list is unique (in the sense that copying it will not leak memory) + if (self.isEmpty()) { + return true; + } + + // otherwise, check if the refcount is one + const ptr: [*]usize = @ptrCast([*]usize, @alignCast(8, self.bytes)); + return (ptr - 1)[0] == utils.REFCOUNT_ONE; + } + + pub fn allocate( + allocator: *Allocator, + alignment: usize, + length: usize, + element_size: usize, + ) RocList { + const data_bytes = length * element_size; + + return RocList{ + .bytes = utils.allocateWithRefcount(allocator, alignment, data_bytes), + .length = length, + }; + } + + pub fn makeUnique(self: RocList, allocator: *Allocator, alignment: usize, element_width: usize) RocList { + if (self.isEmpty()) { + return self; + } + + if (self.isUnique()) { + return self; + } + + // unfortunately, we have to clone + var new_list = RocList.allocate(allocator, self.length, alignment, element_width); + + var old_bytes: [*]u8 = @ptrCast([*]u8, self.bytes); + var new_bytes: [*]u8 = @ptrCast([*]u8, new_list.bytes); + + const number_of_bytes = self.len() * element_width; + @memcpy(new_bytes, old_bytes, number_of_bytes); + + // NOTE we fuse an increment of all keys/values with a decrement of the input dict + const data_bytes = self.len() * element_width; + utils.decref(allocator, alignment, self.bytes, data_bytes); + + return new_list; + } + + pub fn reallocate( + self: RocList, + allocator: *Allocator, + alignment: usize, + new_length: usize, + element_width: usize, + ) RocList { + const old_length = self.length; + const delta_length = new_length - old_length; + + const data_bytes = new_capacity * slot_size; + const first_slot = allocateWithRefcount(allocator, alignment, data_bytes); + + // transfer the memory + + if (self.bytes) |source_ptr| { + const dest_ptr = first_slot; + + @memcpy(dest_ptr, source_ptr, old_length); + } + + // NOTE the newly added elements are left uninitialized + + const result = RocList{ + .dict_bytes = first_slot, + .length = new_length, + }; + + // NOTE we fuse an increment of all keys/values with a decrement of the input dict + utils.decref(allocator, alignment, self.bytes, old_length * element_width); + + return result; + } +}; + +const Caller1 = fn (?[*]u8, ?[*]u8, ?[*]u8) callconv(.C) void; +const Caller2 = fn (?[*]u8, ?[*]u8, ?[*]u8, ?[*]u8) callconv(.C) void; + +pub fn listMap(list: RocList, transform: Opaque, caller: Caller1, alignment: usize, old_element_width: usize, new_element_width: usize) callconv(.C) RocList { + if (list.bytes) |source_ptr| { + const size = list.len(); + var i: usize = 0; + const output = RocList.allocate(std.heap.c_allocator, alignment, size, new_element_width); + const target_ptr = output.bytes orelse unreachable; + + while (i < size) : (i += 1) { + caller(transform, source_ptr + (i * old_element_width), target_ptr + (i * new_element_width)); + } + + utils.decref(std.heap.c_allocator, alignment, list.bytes, size * old_element_width); + + return output; + } else { + return RocList.empty(); + } +} + +pub fn listMapWithIndex(list: RocList, transform: Opaque, caller: Caller2, alignment: usize, old_element_width: usize, new_element_width: usize) callconv(.C) RocList { + if (list.bytes) |source_ptr| { + const size = list.len(); + var i: usize = 0; + const output = RocList.allocate(std.heap.c_allocator, alignment, size, new_element_width); + const target_ptr = output.bytes orelse unreachable; + + while (i < size) : (i += 1) { + caller(transform, @ptrCast(?[*]u8, &i), source_ptr + (i * old_element_width), target_ptr + (i * new_element_width)); + } + + utils.decref(std.heap.c_allocator, alignment, list.bytes, size * old_element_width); + + return output; + } else { + return RocList.empty(); + } +} + +pub fn listKeepIf(list: RocList, transform: Opaque, caller: Caller1, alignment: usize, element_width: usize) callconv(.C) RocList { + if (list.bytes) |source_ptr| { + const size = list.len(); + var i: usize = 0; + var output = RocList.allocate(std.heap.c_allocator, alignment, list.len(), list.len() * element_width); + const target_ptr = output.bytes orelse unreachable; + + var kept: usize = 0; + while (i < size) : (i += 1) { + var keep = false; + const element = source_ptr + (i * element_width); + caller(transform, element, @ptrCast(?[*]u8, &keep)); + + if (keep) { + @memcpy(target_ptr + (kept * element_width), element, element_width); + + kept += 1; + } else { + // TODO decrement the value? + } + } + + output.length = kept; + + utils.decref(std.heap.c_allocator, alignment, list.bytes, size * element_width); + + return output; + } else { + return RocList.empty(); + } +} + +pub fn listKeepOks(list: RocList, transform: Opaque, caller: Caller1, alignment: usize, before_width: usize, result_width: usize, after_width: usize) callconv(.C) RocList { + return listKeepResult(list, RocResult.isOk, transform, caller, alignment, before_width, result_width, after_width); +} + +pub fn listKeepErrs(list: RocList, transform: Opaque, caller: Caller1, alignment: usize, before_width: usize, result_width: usize, after_width: usize) callconv(.C) RocList { + return listKeepResult(list, RocResult.isErr, transform, caller, alignment, before_width, result_width, after_width); +} + +pub fn listKeepResult(list: RocList, is_good_constructor: fn (RocResult) bool, transform: Opaque, caller: Caller1, alignment: usize, before_width: usize, result_width: usize, after_width: usize) RocList { + if (list.bytes) |source_ptr| { + const size = list.len(); + var i: usize = 0; + var output = RocList.allocate(std.heap.c_allocator, alignment, list.len(), list.len() * after_width); + const target_ptr = output.bytes orelse unreachable; + + var temporary = @ptrCast([*]u8, std.heap.c_allocator.alloc(u8, result_width) catch unreachable); + + var kept: usize = 0; + while (i < size) : (i += 1) { + const element = source_ptr + (i * before_width); + caller(transform, element, temporary); + + const result = utils.RocResult{ .bytes = temporary }; + + if (is_good_constructor(result)) { + @memcpy(target_ptr + (kept * after_width), temporary + @sizeOf(i64), after_width); + + kept += 1; + } + } + + output.length = kept; + + utils.decref(std.heap.c_allocator, alignment, list.bytes, size * before_width); + + return output; + } else { + return RocList.empty(); + } +} + +pub fn listWalk(list: RocList, stepper: Opaque, stepper_caller: Caller2, accum: Opaque, alignment: usize, element_width: usize, accum_width: usize, output: Opaque) callconv(.C) void { + if (accum_width == 0) { + return; + } + + @memcpy(output orelse unreachable, accum orelse unreachable, accum_width); + + if (list.bytes) |source_ptr| { + var i: usize = 0; + const size = list.len(); + while (i < size) : (i += 1) { + const element = source_ptr + i * element_width; + stepper_caller(stepper, element, output, output); + } + + const data_bytes = list.len() * element_width; + utils.decref(std.heap.c_allocator, alignment, list.bytes, data_bytes); + } +} + +pub fn listWalkBackwards(list: RocList, stepper: Opaque, stepper_caller: Caller2, accum: Opaque, alignment: usize, element_width: usize, accum_width: usize, output: Opaque) callconv(.C) void { + if (accum_width == 0) { + return; + } + + @memcpy(output orelse unreachable, accum orelse unreachable, accum_width); + + if (list.bytes) |source_ptr| { + const size = list.len(); + var i: usize = size; + while (i > 0) { + i -= 1; + const element = source_ptr + i * element_width; + stepper_caller(stepper, element, output, output); + } + + const data_bytes = list.len() * element_width; + utils.decref(std.heap.c_allocator, alignment, list.bytes, data_bytes); + } +} + +// List.contains : List k, k -> Bool +pub fn listContains(list: RocList, key: Opaque, key_width: usize, is_eq: EqFn) callconv(.C) bool { + if (list.bytes) |source_ptr| { + const size = list.len(); + var i: usize = 0; + while (i < size) : (i += 1) { + const element = source_ptr + i * key_width; + if (is_eq(element, key)) { + return true; + } + } + } + + return false; +} diff --git a/compiler/builtins/bitcode/src/main.zig b/compiler/builtins/bitcode/src/main.zig index 7bad53bdfa..1ecb482406 100644 --- a/compiler/builtins/bitcode/src/main.zig +++ b/compiler/builtins/bitcode/src/main.zig @@ -2,6 +2,20 @@ const builtin = @import("builtin"); const std = @import("std"); const testing = std.testing; +// List Module +const list = @import("list.zig"); + +comptime { + exportListFn(list.listMap, "map"); + exportListFn(list.listMapWithIndex, "map_with_index"); + exportListFn(list.listKeepIf, "keep_if"); + exportListFn(list.listWalk, "walk"); + exportListFn(list.listWalkBackwards, "walk_backwards"); + exportListFn(list.listKeepOks, "keep_oks"); + exportListFn(list.listKeepErrs, "keep_errs"); + exportListFn(list.listContains, "contains"); +} + // Dict Module const dict = @import("dict.zig"); const hash = @import("hash.zig"); @@ -16,6 +30,12 @@ comptime { exportDictFn(dict.elementsRc, "elementsRc"); exportDictFn(dict.dictKeys, "keys"); exportDictFn(dict.dictValues, "values"); + exportDictFn(dict.dictUnion, "union"); + exportDictFn(dict.dictIntersection, "intersection"); + exportDictFn(dict.dictDifference, "difference"); + exportDictFn(dict.dictWalk, "walk"); + + exportDictFn(dict.setFromList, "set_from_list"); exportDictFn(hash.wyhash, "hash"); exportDictFn(hash.wyhash_rocstr, "hash_str"); @@ -43,6 +63,7 @@ comptime { exportStrFn(str.strJoinWithC, "joinWith"); exportStrFn(str.strNumberOfBytes, "number_of_bytes"); exportStrFn(str.strFromIntC, "from_int"); + exportStrFn(str.strFromFloatC, "from_float"); exportStrFn(str.strEqual, "equal"); } @@ -60,6 +81,10 @@ fn exportDictFn(comptime func: anytype, comptime func_name: []const u8) void { exportBuiltinFn(func, "dict." ++ func_name); } +fn exportListFn(comptime func: anytype, comptime func_name: []const u8) void { + exportBuiltinFn(func, "list." ++ func_name); +} + // Run all tests in imported modules // https://github.com/ziglang/zig/blob/master/lib/std/std.zig#L94 test "" { diff --git a/compiler/builtins/bitcode/src/str.zig b/compiler/builtins/bitcode/src/str.zig index 16957835cd..97a4928ee3 100644 --- a/compiler/builtins/bitcode/src/str.zig +++ b/compiler/builtins/bitcode/src/str.zig @@ -302,6 +302,23 @@ fn strFromIntHelp(allocator: *Allocator, comptime T: type, int: T) RocStr { return RocStr.init(allocator, &buf, result.len); } +// Str.fromFloat +// When we actually use this in Roc, libc will be linked so we have access to std.heap.c_allocator +pub fn strFromFloatC(float: f64) callconv(.C) RocStr { + // NOTE the compiled zig for float formatting seems to use LLVM11-specific features + // hopefully we can use zig instead of snprintf in the future when we upgrade + const c = @cImport({ + // See https://github.com/ziglang/zig/issues/515 + @cDefine("_NO_CRT_STDIO_INLINE", "1"); + @cInclude("stdio.h"); + }); + var buf: [100]u8 = undefined; + + const result = c.snprintf(&buf, 100, "%f", float); + + return RocStr.init(std.heap.c_allocator, &buf, @intCast(usize, result)); +} + // Str.split // When we actually use this in Roc, libc will be linked so we have access to std.heap.c_allocator pub fn strSplitInPlaceC(array: [*]RocStr, string: RocStr, delimiter: RocStr) callconv(.C) void { diff --git a/compiler/builtins/bitcode/src/utils.zig b/compiler/builtins/bitcode/src/utils.zig new file mode 100644 index 0000000000..e95f3ab711 --- /dev/null +++ b/compiler/builtins/bitcode/src/utils.zig @@ -0,0 +1,107 @@ +const std = @import("std"); +const Allocator = std.mem.Allocator; + +const REFCOUNT_ONE_ISIZE: comptime isize = std.math.minInt(isize); +pub const REFCOUNT_ONE: usize = @bitCast(usize, REFCOUNT_ONE_ISIZE); + +pub fn decref( + allocator: *Allocator, + alignment: usize, + bytes_or_null: ?[*]u8, + data_bytes: usize, +) void { + if (data_bytes == 0) { + return; + } + + var bytes = bytes_or_null orelse return; + + const usizes: [*]usize = @ptrCast([*]usize, @alignCast(8, bytes)); + + const refcount = (usizes - 1)[0]; + const refcount_isize = @bitCast(isize, refcount); + + switch (alignment) { + 16 => { + if (refcount == REFCOUNT_ONE) { + allocator.free((bytes - 16)[0 .. 16 + data_bytes]); + } else if (refcount_isize < 0) { + (usizes - 1)[0] = refcount + 1; + } + }, + else => { + // NOTE enums can currently have an alignment of < 8 + if (refcount == REFCOUNT_ONE) { + allocator.free((bytes - 8)[0 .. 8 + data_bytes]); + } else if (refcount_isize < 0) { + (usizes - 1)[0] = refcount + 1; + } + }, + } +} + +pub fn allocateWithRefcount( + allocator: *Allocator, + alignment: usize, + data_bytes: usize, +) [*]u8 { + comptime const result_in_place = false; + + switch (alignment) { + 16 => { + const length = 2 * @sizeOf(usize) + data_bytes; + + var new_bytes: []align(16) u8 = allocator.alignedAlloc(u8, 16, length) catch unreachable; + + var as_usize_array = @ptrCast([*]usize, new_bytes); + if (result_in_place) { + as_usize_array[0] = 0; + as_usize_array[1] = @intCast(usize, number_of_slots); + } else { + as_usize_array[0] = 0; + as_usize_array[1] = REFCOUNT_ONE; + } + + var as_u8_array = @ptrCast([*]u8, new_bytes); + const first_slot = as_u8_array + 2 * @sizeOf(usize); + + return first_slot; + }, + else => { + const length = @sizeOf(usize) + data_bytes; + + var new_bytes: []align(8) u8 = allocator.alignedAlloc(u8, 8, length) catch unreachable; + + var as_usize_array = @ptrCast([*]usize, new_bytes); + if (result_in_place) { + as_usize_array[0] = @intCast(usize, number_of_slots); + } else { + as_usize_array[0] = REFCOUNT_ONE; + } + + var as_u8_array = @ptrCast([*]u8, new_bytes); + const first_slot = as_u8_array + @sizeOf(usize); + + return first_slot; + }, + } +} + +pub const RocResult = extern struct { + bytes: ?[*]u8, + + pub fn isOk(self: RocResult) bool { + // assumptions + // + // - the tag is the first field + // - the tag is usize bytes wide + // - Ok has tag_id 1, because Err < Ok + const usizes: [*]usize = @ptrCast([*]usize, @alignCast(8, self.bytes)); + + return usizes[0] == 1; + } + + pub fn isErr(self: RocResult) bool { + return !self.isOk(); + } +}; diff --git a/compiler/builtins/src/bitcode.rs b/compiler/builtins/src/bitcode.rs index fc66f4c584..a220151bc9 100644 --- a/compiler/builtins/src/bitcode.rs +++ b/compiler/builtins/src/bitcode.rs @@ -38,6 +38,7 @@ pub const STR_STARTS_WITH: &str = "roc_builtins.str.starts_with"; pub const STR_ENDS_WITH: &str = "roc_builtins.str.ends_with"; pub const STR_NUMBER_OF_BYTES: &str = "roc_builtins.str.number_of_bytes"; pub const STR_FROM_INT: &str = "roc_builtins.str.from_int"; +pub const STR_FROM_FLOAT: &str = "roc_builtins.str.from_float"; pub const STR_EQUAL: &str = "roc_builtins.str.equal"; pub const DICT_HASH: &str = "roc_builtins.dict.hash"; @@ -51,3 +52,18 @@ pub const DICT_GET: &str = "roc_builtins.dict.get"; pub const DICT_ELEMENTS_RC: &str = "roc_builtins.dict.elementsRc"; pub const DICT_KEYS: &str = "roc_builtins.dict.keys"; pub const DICT_VALUES: &str = "roc_builtins.dict.values"; +pub const DICT_UNION: &str = "roc_builtins.dict.union"; +pub const DICT_DIFFERENCE: &str = "roc_builtins.dict.difference"; +pub const DICT_INTERSECTION: &str = "roc_builtins.dict.intersection"; +pub const DICT_WALK: &str = "roc_builtins.dict.walk"; + +pub const SET_FROM_LIST: &str = "roc_builtins.dict.set_from_list"; + +pub const LIST_MAP: &str = "roc_builtins.list.map"; +pub const LIST_MAP_WITH_INDEX: &str = "roc_builtins.list.map_with_index"; +pub const LIST_KEEP_IF: &str = "roc_builtins.list.keep_if"; +pub const LIST_KEEP_OKS: &str = "roc_builtins.list.keep_oks"; +pub const LIST_KEEP_ERRS: &str = "roc_builtins.list.keep_errs"; +pub const LIST_WALK: &str = "roc_builtins.list.walk"; +pub const LIST_WALK_BACKWARDS: &str = "roc_builtins.list.walk_backwards"; +pub const LIST_CONTAINS: &str = "roc_builtins.list.contains"; diff --git a/compiler/builtins/src/std.rs b/compiler/builtins/src/std.rs index 7d66ed2d6c..42eb034581 100644 --- a/compiler/builtins/src/std.rs +++ b/compiler/builtins/src/std.rs @@ -10,6 +10,30 @@ use roc_types::solved_types::SolvedType; use roc_types::subs::VarId; use std::collections::HashMap; +/// Example: +/// +/// let_tvars! { a, b, c } +/// +/// This is equivalent to: +/// +/// let a = VarId::from_u32(1); +/// let b = VarId::from_u32(2); +/// let c = VarId::from_u32(3); +/// +/// The idea is that this is less error-prone than assigning hardcoded IDs by hand. +macro_rules! let_tvars { + ($($name:ident,)+) => { let_tvars!($($name),+) }; + ($($name:ident),*) => { + let mut _current_tvar = 0; + + $( + _current_tvar += 1; + + let $name = VarId::from_u32(_current_tvar); + )* + }; +} + #[derive(Clone, Copy, Debug)] pub enum Mode { Standard, @@ -539,6 +563,12 @@ pub fn types() -> MutMap { top_level_function(vec![int_type(flex(TVAR1))], Box::new(str_type())), ); + // fromFloat : Float a -> Str + add_type( + Symbol::STR_FROM_FLOAT, + top_level_function(vec![float_type(flex(TVAR1))], Box::new(str_type())), + ); + // List module // get : List elem, Nat -> Result elem [ OutOfBounds ]* @@ -652,6 +682,38 @@ pub fn types() -> MutMap { ), ); + // keepOks : List before, (before -> Result after *) -> List after + add_type(Symbol::LIST_KEEP_OKS, { + let_tvars! { star, cvar, before, after}; + top_level_function( + vec![ + list_type(flex(before)), + closure( + vec![flex(before)], + cvar, + Box::new(result_type(flex(after), flex(star))), + ), + ], + Box::new(list_type(flex(after))), + ) + }); + + // keepOks : List before, (before -> Result * after) -> List after + add_type(Symbol::LIST_KEEP_ERRS, { + let_tvars! { star, cvar, before, after}; + top_level_function( + vec![ + list_type(flex(before)), + closure( + vec![flex(before)], + cvar, + Box::new(result_type(flex(star), flex(after))), + ), + ], + Box::new(list_type(flex(after))), + ) + }); + // map : List before, (before -> after) -> List after add_type( Symbol::LIST_MAP, @@ -664,6 +726,18 @@ pub fn types() -> MutMap { ), ); + // mapWithIndex : List before, (Nat, before -> after) -> List after + add_type(Symbol::LIST_MAP_WITH_INDEX, { + let_tvars! { cvar, before, after}; + top_level_function( + vec![ + list_type(flex(before)), + closure(vec![nat_type(), flex(before)], cvar, Box::new(flex(after))), + ], + Box::new(list_type(flex(after))), + ) + }); + // append : List elem, elem -> List elem add_type( Symbol::LIST_APPEND, @@ -819,6 +893,59 @@ pub fn types() -> MutMap { ), ); + // Dict.union : Dict k v, Dict k v -> Dict k v + add_type( + Symbol::DICT_UNION, + top_level_function( + vec![ + dict_type(flex(TVAR1), flex(TVAR2)), + dict_type(flex(TVAR1), flex(TVAR2)), + ], + Box::new(dict_type(flex(TVAR1), flex(TVAR2))), + ), + ); + + // Dict.intersection : Dict k v, Dict k v -> Dict k v + add_type( + Symbol::DICT_INTERSECTION, + top_level_function( + vec![ + dict_type(flex(TVAR1), flex(TVAR2)), + dict_type(flex(TVAR1), flex(TVAR2)), + ], + Box::new(dict_type(flex(TVAR1), flex(TVAR2))), + ), + ); + + // Dict.difference : Dict k v, Dict k v -> Dict k v + add_type( + Symbol::DICT_DIFFERENCE, + top_level_function( + vec![ + dict_type(flex(TVAR1), flex(TVAR2)), + dict_type(flex(TVAR1), flex(TVAR2)), + ], + Box::new(dict_type(flex(TVAR1), flex(TVAR2))), + ), + ); + + // Dict.walk : Dict k v, (k, v, accum -> accum), accum -> accum + add_type( + Symbol::DICT_WALK, + top_level_function( + vec![ + dict_type(flex(TVAR1), flex(TVAR2)), + closure( + vec![flex(TVAR1), flex(TVAR2), flex(TVAR3)], + TVAR4, + Box::new(flex(TVAR3)), + ), + flex(TVAR3), + ], + Box::new(flex(TVAR3)), + ), + ); + // Set module // empty : Set a @@ -830,6 +957,30 @@ pub fn types() -> MutMap { top_level_function(vec![flex(TVAR1)], Box::new(set_type(flex(TVAR1)))), ); + // len : Set * -> Nat + add_type( + Symbol::SET_LEN, + top_level_function(vec![set_type(flex(TVAR1))], Box::new(nat_type())), + ); + + // toList : Set a -> List a + add_type( + Symbol::SET_TO_LIST, + top_level_function( + vec![set_type(flex(TVAR1))], + Box::new(list_type(flex(TVAR1))), + ), + ); + + // fromList : Set a -> List a + add_type( + Symbol::SET_FROM_LIST, + top_level_function( + vec![list_type(flex(TVAR1))], + Box::new(set_type(flex(TVAR1))), + ), + ); + // union : Set a, Set a -> Set a add_type( Symbol::SET_UNION, @@ -839,18 +990,27 @@ pub fn types() -> MutMap { ), ); - // diff : Set a, Set a -> Set a + // difference : Set a, Set a -> Set a add_type( - Symbol::SET_DIFF, + Symbol::SET_DIFFERENCE, top_level_function( vec![set_type(flex(TVAR1)), set_type(flex(TVAR1))], Box::new(set_type(flex(TVAR1))), ), ); - // foldl : Set a, (a -> b -> b), b -> b + // intersection : Set a, Set a -> Set a add_type( - Symbol::SET_FOLDL, + Symbol::SET_INTERSECTION, + top_level_function( + vec![set_type(flex(TVAR1)), set_type(flex(TVAR1))], + Box::new(set_type(flex(TVAR1))), + ), + ); + + // Set.walk : Set a, (a, b -> b), b -> b + add_type( + Symbol::SET_WALK, top_level_function( vec![ set_type(flex(TVAR1)), @@ -877,6 +1037,14 @@ pub fn types() -> MutMap { ), ); + add_type( + Symbol::SET_CONTAINS, + top_level_function( + vec![set_type(flex(TVAR1)), flex(TVAR1)], + Box::new(bool_type()), + ), + ); + // Result module // map : Result a err, (a -> b) -> Result b err @@ -891,6 +1059,27 @@ pub fn types() -> MutMap { ), ); + // mapErr : Result a x, (x -> y) -> Result a x + add_type( + Symbol::RESULT_MAP_ERR, + top_level_function( + vec![ + result_type(flex(TVAR1), flex(TVAR3)), + closure(vec![flex(TVAR3)], TVAR4, Box::new(flex(TVAR2))), + ], + Box::new(result_type(flex(TVAR1), flex(TVAR2))), + ), + ); + + // withDefault : Result a x, a -> a + add_type( + Symbol::RESULT_WITH_DEFAULT, + top_level_function( + vec![result_type(flex(TVAR1), flex(TVAR3)), flex(TVAR1)], + Box::new(flex(TVAR1)), + ), + ); + types } diff --git a/compiler/builtins/src/unique.rs b/compiler/builtins/src/unique.rs index f2a4034465..6ffad680e0 100644 --- a/compiler/builtins/src/unique.rs +++ b/compiler/builtins/src/unique.rs @@ -1053,13 +1053,13 @@ pub fn types() -> MutMap { // diff : Attr * (Set * a) // , Attr * (Set * a) // -> Attr * (Set * a) - add_type(Symbol::SET_DIFF, set_combine); + add_type(Symbol::SET_DIFFERENCE, set_combine); // foldl : Attr (* | u) (Set (Attr u a)) // , Attr Shared (Attr u a -> b -> b) // , b // -> b - add_type(Symbol::SET_FOLDL, { + add_type(Symbol::SET_WALK, { let_tvars! { star, u, a, b, closure }; unique_function( diff --git a/compiler/can/src/builtins.rs b/compiler/can/src/builtins.rs index c62d0a37c9..2865f31dff 100644 --- a/compiler/can/src/builtins.rs +++ b/compiler/can/src/builtins.rs @@ -1,10 +1,11 @@ use crate::def::Def; use crate::expr::Expr::*; -use crate::expr::{Expr, Recursive}; +use crate::expr::{Expr, Recursive, WhenBranch}; use crate::pattern::Pattern; use roc_collections::all::{MutMap, SendMap}; use roc_module::ident::TagName; use roc_module::low_level::LowLevel; +use roc_module::operator::CalledVia; use roc_module::symbol::Symbol; use roc_region::all::{Located, Region}; use roc_types::subs::{VarStore, Variable}; @@ -60,6 +61,7 @@ pub fn builtin_defs_map(symbol: Symbol, var_store: &mut VarStore) -> Option STR_ENDS_WITH => str_ends_with, STR_COUNT_GRAPHEMES => str_count_graphemes, STR_FROM_INT => str_from_int, + STR_FROM_FLOAT=> str_from_float, LIST_LEN => list_len, LIST_GET => list_get, LIST_SET => list_set, @@ -76,18 +78,38 @@ pub fn builtin_defs_map(symbol: Symbol, var_store: &mut VarStore) -> Option LIST_PREPEND => list_prepend, LIST_JOIN => list_join, LIST_MAP => list_map, + LIST_MAP_WITH_INDEX => list_map_with_index, LIST_KEEP_IF => list_keep_if, + LIST_KEEP_OKS => list_keep_oks, + LIST_KEEP_ERRS=> list_keep_errs, LIST_WALK => list_walk, LIST_WALK_BACKWARDS => list_walk_backwards, DICT_TEST_HASH => dict_hash_test_only, DICT_LEN => dict_len, DICT_EMPTY => dict_empty, + DICT_SINGLETON => dict_singleton, DICT_INSERT => dict_insert, DICT_REMOVE => dict_remove, DICT_GET => dict_get, DICT_CONTAINS => dict_contains, DICT_KEYS => dict_keys, DICT_VALUES => dict_values, + DICT_UNION=> dict_union, + DICT_INTERSECTION=> dict_intersection, + DICT_DIFFERENCE=> dict_difference, + DICT_WALK=> dict_walk, + SET_EMPTY => set_empty, + SET_LEN => set_len, + SET_SINGLETON => set_singleton, + SET_UNION=> set_union, + SET_INTERSECTION => set_intersection, + SET_DIFFERENCE => set_difference, + SET_TO_LIST => set_to_list, + SET_FROM_LIST => set_from_list, + SET_INSERT => set_insert, + SET_REMOVE => set_remove, + SET_CONTAINS => set_contains, + SET_WALK=> set_walk, NUM_ADD => num_add, NUM_ADD_CHECKED => num_add_checked, NUM_ADD_WRAP => num_add_wrap, @@ -128,7 +150,10 @@ pub fn builtin_defs_map(symbol: Symbol, var_store: &mut VarStore) -> Option NUM_MAX_INT => num_max_int, NUM_MIN_INT => num_min_int, NUM_BITWISE_AND => num_bitwise_and, - NUM_BITWISE_XOR => num_bitwise_xor + NUM_BITWISE_XOR => num_bitwise_xor, + RESULT_MAP => result_map, + RESULT_MAP_ERR => result_map_err, + RESULT_WITH_DEFAULT => result_with_default, } } @@ -164,6 +189,7 @@ pub fn builtin_defs(var_store: &mut VarStore) -> MutMap { Symbol::STR_ENDS_WITH => str_ends_with, Symbol::STR_COUNT_GRAPHEMES => str_count_graphemes, Symbol::STR_FROM_INT => str_from_int, + Symbol::STR_FROM_FLOAT=> str_from_float, Symbol::LIST_LEN => list_len, Symbol::LIST_GET => list_get, Symbol::LIST_SET => list_set, @@ -180,18 +206,38 @@ pub fn builtin_defs(var_store: &mut VarStore) -> MutMap { Symbol::LIST_PREPEND => list_prepend, Symbol::LIST_JOIN => list_join, Symbol::LIST_MAP => list_map, + Symbol::LIST_MAP_WITH_INDEX => list_map_with_index, Symbol::LIST_KEEP_IF => list_keep_if, + Symbol::LIST_KEEP_OKS => list_keep_oks, + Symbol::LIST_KEEP_ERRS=> list_keep_errs, Symbol::LIST_WALK => list_walk, Symbol::LIST_WALK_BACKWARDS => list_walk_backwards, Symbol::DICT_TEST_HASH => dict_hash_test_only, Symbol::DICT_LEN => dict_len, Symbol::DICT_EMPTY => dict_empty, + Symbol::DICT_SINGLETON => dict_singleton, Symbol::DICT_INSERT => dict_insert, Symbol::DICT_REMOVE => dict_remove, Symbol::DICT_GET => dict_get, Symbol::DICT_CONTAINS => dict_contains, Symbol::DICT_KEYS => dict_keys, Symbol::DICT_VALUES => dict_values, + Symbol::DICT_UNION=> dict_union, + Symbol::DICT_INTERSECTION=> dict_intersection, + Symbol::DICT_DIFFERENCE=> dict_difference, + Symbol::DICT_WALK=> dict_walk, + Symbol::SET_EMPTY => set_empty, + Symbol::SET_LEN => set_len, + Symbol::SET_SINGLETON => set_singleton, + Symbol::SET_UNION=> set_union, + Symbol::SET_INTERSECTION=> set_intersection, + Symbol::SET_DIFFERENCE=> set_difference, + Symbol::SET_TO_LIST => set_to_list, + Symbol::SET_FROM_LIST => set_from_list, + Symbol::SET_INSERT => set_insert, + Symbol::SET_REMOVE => set_remove, + Symbol::SET_CONTAINS => set_contains, + Symbol::SET_WALK => set_walk, Symbol::NUM_ADD => num_add, Symbol::NUM_ADD_CHECKED => num_add_checked, Symbol::NUM_ADD_WRAP => num_add_wrap, @@ -227,9 +273,83 @@ pub fn builtin_defs(var_store: &mut VarStore) -> MutMap { Symbol::NUM_ASIN => num_asin, Symbol::NUM_MAX_INT => num_max_int, Symbol::NUM_MIN_INT => num_min_int, + Symbol::RESULT_MAP => result_map, + Symbol::RESULT_MAP_ERR => result_map_err, + Symbol::RESULT_WITH_DEFAULT => result_with_default, } } +fn lowlevel_1(symbol: Symbol, op: LowLevel, var_store: &mut VarStore) -> Def { + let arg1_var = var_store.fresh(); + let ret_var = var_store.fresh(); + + let body = RunLowLevel { + op, + args: vec![(arg1_var, Var(Symbol::ARG_1))], + ret_var, + }; + + defn( + symbol, + vec![(arg1_var, Symbol::ARG_1)], + var_store, + body, + ret_var, + ) +} + +fn lowlevel_2(symbol: Symbol, op: LowLevel, var_store: &mut VarStore) -> Def { + let arg1_var = var_store.fresh(); + let arg2_var = var_store.fresh(); + let ret_var = var_store.fresh(); + + let body = RunLowLevel { + op, + args: vec![ + (arg1_var, Var(Symbol::ARG_1)), + (arg2_var, Var(Symbol::ARG_2)), + ], + ret_var, + }; + + defn( + symbol, + vec![(arg1_var, Symbol::ARG_1), (arg2_var, Symbol::ARG_2)], + var_store, + body, + ret_var, + ) +} + +fn lowlevel_3(symbol: Symbol, op: LowLevel, var_store: &mut VarStore) -> Def { + let arg1_var = var_store.fresh(); + let arg2_var = var_store.fresh(); + let arg3_var = var_store.fresh(); + let ret_var = var_store.fresh(); + + let body = RunLowLevel { + op, + args: vec![ + (arg1_var, Var(Symbol::ARG_1)), + (arg2_var, Var(Symbol::ARG_2)), + (arg3_var, Var(Symbol::ARG_3)), + ], + ret_var, + }; + + defn( + symbol, + vec![ + (arg1_var, Symbol::ARG_1), + (arg2_var, Symbol::ARG_2), + (arg3_var, Symbol::ARG_3), + ], + var_store, + body, + ret_var, + ) +} + /// Num.maxInt : Int fn num_max_int(symbol: Symbol, var_store: &mut VarStore) -> Def { let int_var = var_store.fresh(); @@ -1372,7 +1492,7 @@ fn str_count_graphemes(symbol: Symbol, var_store: &mut VarStore) -> Def { ) } -/// Str.fromInt : Int -> Str +/// Str.fromInt : Int * -> Str fn str_from_int(symbol: Symbol, var_store: &mut VarStore) -> Def { let int_var = var_store.fresh(); let str_var = var_store.fresh(); @@ -1392,6 +1512,26 @@ fn str_from_int(symbol: Symbol, var_store: &mut VarStore) -> Def { ) } +/// Str.fromFloat : Float * -> Str +fn str_from_float(symbol: Symbol, var_store: &mut VarStore) -> Def { + let float_var = var_store.fresh(); + let str_var = var_store.fresh(); + + let body = RunLowLevel { + op: LowLevel::StrFromFloat, + args: vec![(float_var, Var(Symbol::ARG_1))], + ret_var: str_var, + }; + + defn( + symbol, + vec![(float_var, Symbol::ARG_1)], + var_store, + body, + str_var, + ) +} + /// List.concat : List elem, List elem -> List elem fn list_concat(symbol: Symbol, var_store: &mut VarStore) -> Def { let list_var = var_store.fresh(); @@ -1800,94 +1940,37 @@ fn list_keep_if(symbol: Symbol, var_store: &mut VarStore) -> Def { /// List.contains : List elem, elem -> Bool fn list_contains(symbol: Symbol, var_store: &mut VarStore) -> Def { - let list_var = var_store.fresh(); - let elem_var = var_store.fresh(); - let bool_var = var_store.fresh(); + lowlevel_2(symbol, LowLevel::ListContains, var_store) +} - let body = RunLowLevel { - op: LowLevel::ListContains, - args: vec![ - (list_var, Var(Symbol::ARG_1)), - (elem_var, Var(Symbol::ARG_2)), - ], - ret_var: bool_var, - }; +/// List.keepOks : List before, (before -> Result after *) -> List after +fn list_keep_oks(symbol: Symbol, var_store: &mut VarStore) -> Def { + lowlevel_2(symbol, LowLevel::ListKeepOks, var_store) +} - defn( - symbol, - vec![(list_var, Symbol::ARG_1), (elem_var, Symbol::ARG_2)], - var_store, - body, - bool_var, - ) +/// List.keepErrs: List before, (before -> Result * after) -> List after +fn list_keep_errs(symbol: Symbol, var_store: &mut VarStore) -> Def { + lowlevel_2(symbol, LowLevel::ListKeepErrs, var_store) } /// List.map : List before, (before -> after) -> List after fn list_map(symbol: Symbol, var_store: &mut VarStore) -> Def { - let list_var = var_store.fresh(); - let func_var = var_store.fresh(); - let ret_list_var = var_store.fresh(); + lowlevel_2(symbol, LowLevel::ListMap, var_store) +} - let body = RunLowLevel { - op: LowLevel::ListMap, - args: vec![ - (list_var, Var(Symbol::ARG_1)), - (func_var, Var(Symbol::ARG_2)), - ], - ret_var: ret_list_var, - }; - - defn( - symbol, - vec![(list_var, Symbol::ARG_1), (func_var, Symbol::ARG_2)], - var_store, - body, - ret_list_var, - ) +/// List.mapWithIndex : List before, (Nat, before -> after) -> List after +fn list_map_with_index(symbol: Symbol, var_store: &mut VarStore) -> Def { + lowlevel_2(symbol, LowLevel::ListMapWithIndex, var_store) } /// Dict.hashTestOnly : k, v -> Nat pub fn dict_hash_test_only(symbol: Symbol, var_store: &mut VarStore) -> Def { - let key_var = var_store.fresh(); - let value_var = var_store.fresh(); - let nat_var = var_store.fresh(); - - let body = RunLowLevel { - op: LowLevel::Hash, - args: vec![ - (key_var, Var(Symbol::ARG_1)), - (value_var, Var(Symbol::ARG_2)), - ], - ret_var: nat_var, - }; - - defn( - symbol, - vec![(key_var, Symbol::ARG_1), (value_var, Symbol::ARG_2)], - var_store, - body, - nat_var, - ) + lowlevel_2(symbol, LowLevel::Hash, var_store) } /// Dict.len : Dict * * -> Nat fn dict_len(symbol: Symbol, var_store: &mut VarStore) -> Def { - let size_var = var_store.fresh(); - let dict_var = var_store.fresh(); - - let body = RunLowLevel { - op: LowLevel::DictSize, - args: vec![(dict_var, Var(Symbol::ARG_1))], - ret_var: size_var, - }; - - defn( - symbol, - vec![(dict_var, Symbol::ARG_1)], - var_store, - body, - size_var, - ) + lowlevel_1(symbol, LowLevel::DictSize, var_store) } /// Dict.empty : Dict * * @@ -1908,80 +1991,50 @@ fn dict_empty(symbol: Symbol, var_store: &mut VarStore) -> Def { } } -/// Dict.insert : Dict k v, k, v -> Dict k v -fn dict_insert(symbol: Symbol, var_store: &mut VarStore) -> Def { - let dict_var = var_store.fresh(); +/// Dict.singleton : k, v -> Dict k v +fn dict_singleton(symbol: Symbol, var_store: &mut VarStore) -> Def { let key_var = var_store.fresh(); - let val_var = var_store.fresh(); + let value_var = var_store.fresh(); + let dict_var = var_store.fresh(); + + let empty = RunLowLevel { + op: LowLevel::DictEmpty, + args: vec![], + ret_var: dict_var, + }; let body = RunLowLevel { op: LowLevel::DictInsert, args: vec![ - (dict_var, Var(Symbol::ARG_1)), - (key_var, Var(Symbol::ARG_2)), - (val_var, Var(Symbol::ARG_3)), + (dict_var, empty), + (key_var, Var(Symbol::ARG_1)), + (value_var, Var(Symbol::ARG_2)), ], ret_var: dict_var, }; defn( symbol, - vec![ - (dict_var, Symbol::ARG_1), - (key_var, Symbol::ARG_2), - (val_var, Symbol::ARG_3), - ], + vec![(key_var, Symbol::ARG_1), (value_var, Symbol::ARG_2)], var_store, body, dict_var, ) } +/// Dict.insert : Dict k v, k, v -> Dict k v +fn dict_insert(symbol: Symbol, var_store: &mut VarStore) -> Def { + lowlevel_3(symbol, LowLevel::DictInsert, var_store) +} + /// Dict.remove : Dict k v, k -> Dict k v fn dict_remove(symbol: Symbol, var_store: &mut VarStore) -> Def { - let dict_var = var_store.fresh(); - let key_var = var_store.fresh(); - - let body = RunLowLevel { - op: LowLevel::DictRemove, - args: vec![ - (dict_var, Var(Symbol::ARG_1)), - (key_var, Var(Symbol::ARG_2)), - ], - ret_var: dict_var, - }; - - defn( - symbol, - vec![(dict_var, Symbol::ARG_1), (key_var, Symbol::ARG_2)], - var_store, - body, - dict_var, - ) + lowlevel_2(symbol, LowLevel::DictRemove, var_store) } /// Dict.contains : Dict k v, k -> Bool fn dict_contains(symbol: Symbol, var_store: &mut VarStore) -> Def { - let dict_var = var_store.fresh(); - let key_var = var_store.fresh(); - let bool_var = var_store.fresh(); - - let body = RunLowLevel { - op: LowLevel::DictContains, - args: vec![ - (dict_var, Var(Symbol::ARG_1)), - (key_var, Var(Symbol::ARG_2)), - ], - ret_var: bool_var, - }; - - defn( - symbol, - vec![(dict_var, Symbol::ARG_1), (key_var, Symbol::ARG_2)], - var_store, - body, - bool_var, - ) + lowlevel_2(symbol, LowLevel::DictContains, var_store) } /// Dict.get : Dict k v, k -> Result v [ KeyNotFound ]* @@ -2067,41 +2120,208 @@ fn dict_get(symbol: Symbol, var_store: &mut VarStore) -> Def { /// Dict.keys : Dict k v -> List k fn dict_keys(symbol: Symbol, var_store: &mut VarStore) -> Def { - let dict_var = var_store.fresh(); - let list_var = var_store.fresh(); - - let body = RunLowLevel { - op: LowLevel::DictKeys, - args: vec![(dict_var, Var(Symbol::ARG_1))], - ret_var: list_var, - }; - - defn( - symbol, - vec![(dict_var, Symbol::ARG_1)], - var_store, - body, - list_var, - ) + lowlevel_1(symbol, LowLevel::DictKeys, var_store) } /// Dict.values : Dict k v -> List v fn dict_values(symbol: Symbol, var_store: &mut VarStore) -> Def { - let dict_var = var_store.fresh(); - let list_var = var_store.fresh(); + lowlevel_1(symbol, LowLevel::DictValues, var_store) +} + +/// Dict.union : Dict k v, Dict k v -> Dict k v +fn dict_union(symbol: Symbol, var_store: &mut VarStore) -> Def { + lowlevel_2(symbol, LowLevel::DictUnion, var_store) +} + +/// Dict.difference : Dict k v, Dict k v -> Dict k v +fn dict_difference(symbol: Symbol, var_store: &mut VarStore) -> Def { + lowlevel_2(symbol, LowLevel::DictDifference, var_store) +} + +/// Dict.intersection : Dict k v, Dict k v -> Dict k v +fn dict_intersection(symbol: Symbol, var_store: &mut VarStore) -> Def { + lowlevel_2(symbol, LowLevel::DictIntersection, var_store) +} + +/// Dict.walk : Dict k v, (k, v, accum -> accum), accum -> accum +fn dict_walk(symbol: Symbol, var_store: &mut VarStore) -> Def { + lowlevel_3(symbol, LowLevel::DictWalk, var_store) +} + +/// Set.empty : Set * +fn set_empty(symbol: Symbol, var_store: &mut VarStore) -> Def { + let set_var = var_store.fresh(); + let body = RunLowLevel { + op: LowLevel::DictEmpty, + args: vec![], + ret_var: set_var, + }; + + Def { + annotation: None, + expr_var: set_var, + loc_expr: Located::at_zero(body), + loc_pattern: Located::at_zero(Pattern::Identifier(symbol)), + pattern_vars: SendMap::default(), + } +} + +/// Set.singleton : k -> Set k +fn set_singleton(symbol: Symbol, var_store: &mut VarStore) -> Def { + let key_var = var_store.fresh(); + let set_var = var_store.fresh(); + let value_var = Variable::EMPTY_RECORD; + + let empty = RunLowLevel { + op: LowLevel::DictEmpty, + args: vec![], + ret_var: set_var, + }; let body = RunLowLevel { - op: LowLevel::DictValues, - args: vec![(dict_var, Var(Symbol::ARG_1))], - ret_var: list_var, + op: LowLevel::DictInsert, + args: vec![ + (set_var, empty), + (key_var, Var(Symbol::ARG_1)), + (value_var, EmptyRecord), + ], + ret_var: set_var, }; defn( symbol, - vec![(dict_var, Symbol::ARG_1)], + vec![(key_var, Symbol::ARG_1)], var_store, body, - list_var, + set_var, + ) +} + +/// Set.len : Set * -> Nat +fn set_len(symbol: Symbol, var_store: &mut VarStore) -> Def { + lowlevel_1(symbol, LowLevel::DictSize, var_store) +} + +/// Dict.union : Dict k v, Dict k v -> Dict k v +fn set_union(symbol: Symbol, var_store: &mut VarStore) -> Def { + lowlevel_2(symbol, LowLevel::DictUnion, var_store) +} + +/// Dict.difference : Dict k v, Dict k v -> Dict k v +fn set_difference(symbol: Symbol, var_store: &mut VarStore) -> Def { + lowlevel_2(symbol, LowLevel::DictDifference, var_store) +} + +/// Dict.intersection : Dict k v, Dict k v -> Dict k v +fn set_intersection(symbol: Symbol, var_store: &mut VarStore) -> Def { + lowlevel_2(symbol, LowLevel::DictIntersection, var_store) +} + +/// Set.toList : Set k -> List k +fn set_to_list(symbol: Symbol, var_store: &mut VarStore) -> Def { + dict_keys(symbol, var_store) +} + +/// Set.fromList : List k -> Set k +fn set_from_list(symbol: Symbol, var_store: &mut VarStore) -> Def { + lowlevel_1(symbol, LowLevel::SetFromList, var_store) +} + +/// Set.insert : Set k, k -> Set k +fn set_insert(symbol: Symbol, var_store: &mut VarStore) -> Def { + let dict_var = var_store.fresh(); + let key_var = var_store.fresh(); + let val_var = Variable::EMPTY_RECORD; + + let body = RunLowLevel { + op: LowLevel::DictInsert, + args: vec![ + (dict_var, Var(Symbol::ARG_1)), + (key_var, Var(Symbol::ARG_2)), + (val_var, EmptyRecord), + ], + ret_var: dict_var, + }; + + defn( + symbol, + vec![(dict_var, Symbol::ARG_1), (key_var, Symbol::ARG_2)], + var_store, + body, + dict_var, + ) +} + +/// Set.remove : Set k, k -> Set k +fn set_remove(symbol: Symbol, var_store: &mut VarStore) -> Def { + dict_remove(symbol, var_store) +} + +/// Set.remove : Set k, k -> Set k +fn set_contains(symbol: Symbol, var_store: &mut VarStore) -> Def { + dict_contains(symbol, var_store) +} + +/// Set.walk : Set k, (k, accum -> accum), accum -> accum +fn set_walk(symbol: Symbol, var_store: &mut VarStore) -> Def { + let dict_var = var_store.fresh(); + let func_var = var_store.fresh(); + let key_var = var_store.fresh(); + let accum_var = var_store.fresh(); + let wrapper_var = var_store.fresh(); + + let user_function = Box::new(( + func_var, + no_region(Var(Symbol::ARG_2)), + var_store.fresh(), + accum_var, + )); + + let call_func = Call( + user_function, + vec![ + (key_var, no_region(Var(Symbol::ARG_5))), + (accum_var, no_region(Var(Symbol::ARG_6))), + ], + CalledVia::Space, + ); + + let wrapper = Closure { + function_type: wrapper_var, + closure_type: var_store.fresh(), + closure_ext_var: var_store.fresh(), + return_type: accum_var, + name: Symbol::SET_WALK_USER_FUNCTION, + recursive: Recursive::NotRecursive, + captured_symbols: vec![(Symbol::ARG_2, func_var)], + arguments: vec![ + (key_var, no_region(Pattern::Identifier(Symbol::ARG_5))), + (Variable::EMPTY_RECORD, no_region(Pattern::Underscore)), + (accum_var, no_region(Pattern::Identifier(Symbol::ARG_6))), + ], + loc_body: Box::new(no_region(call_func)), + }; + + let body = RunLowLevel { + op: LowLevel::DictWalk, + args: vec![ + (dict_var, Var(Symbol::ARG_1)), + (wrapper_var, wrapper), + (accum_var, Var(Symbol::ARG_3)), + ], + ret_var: accum_var, + }; + + defn( + symbol, + vec![ + (dict_var, Symbol::ARG_1), + (func_var, Symbol::ARG_2), + (accum_var, Symbol::ARG_3), + ], + var_store, + body, + accum_var, ) } @@ -2511,6 +2731,263 @@ fn list_last(symbol: Symbol, var_store: &mut VarStore) -> Def { ) } +fn result_map(symbol: Symbol, var_store: &mut VarStore) -> Def { + let ret_var = var_store.fresh(); + let func_var = var_store.fresh(); + let result_var = var_store.fresh(); + + let mut branches = vec![]; + + { + let user_function = Box::new(( + func_var, + no_region(Var(Symbol::ARG_2)), + var_store.fresh(), + var_store.fresh(), + )); + + let call_func = Call( + user_function, + vec![(var_store.fresh(), no_region(Var(Symbol::ARG_5)))], + CalledVia::Space, + ); + + let tag_name = TagName::Global("Ok".into()); + + // ok branch + let ok = Tag { + variant_var: var_store.fresh(), + ext_var: var_store.fresh(), + name: tag_name.clone(), + arguments: vec![(var_store.fresh(), no_region(call_func))], + }; + + let pattern = Pattern::AppliedTag { + whole_var: result_var, + ext_var: var_store.fresh(), + tag_name, + arguments: vec![( + var_store.fresh(), + no_region(Pattern::Identifier(Symbol::ARG_5)), + )], + }; + + let branch = WhenBranch { + patterns: vec![no_region(pattern)], + value: no_region(ok), + guard: None, + }; + + branches.push(branch); + } + + { + // err branch + let tag_name = TagName::Global("Err".into()); + + let err = Tag { + variant_var: var_store.fresh(), + ext_var: var_store.fresh(), + name: tag_name.clone(), + arguments: vec![(var_store.fresh(), no_region(Var(Symbol::ARG_4)))], + }; + + let pattern = Pattern::AppliedTag { + whole_var: result_var, + ext_var: var_store.fresh(), + tag_name, + arguments: vec![( + var_store.fresh(), + no_region(Pattern::Identifier(Symbol::ARG_4)), + )], + }; + + let branch = WhenBranch { + patterns: vec![no_region(pattern)], + value: no_region(err), + guard: None, + }; + + branches.push(branch); + } + + let body = When { + cond_var: result_var, + expr_var: ret_var, + region: Region::zero(), + loc_cond: Box::new(no_region(Var(Symbol::ARG_1))), + branches, + }; + + defn( + symbol, + vec![(result_var, Symbol::ARG_1), (func_var, Symbol::ARG_2)], + var_store, + body, + ret_var, + ) +} + +fn result_map_err(symbol: Symbol, var_store: &mut VarStore) -> Def { + let ret_var = var_store.fresh(); + let func_var = var_store.fresh(); + let result_var = var_store.fresh(); + + let mut branches = vec![]; + + { + let user_function = Box::new(( + func_var, + no_region(Var(Symbol::ARG_2)), + var_store.fresh(), + var_store.fresh(), + )); + + let call_func = Call( + user_function, + vec![(var_store.fresh(), no_region(Var(Symbol::ARG_5)))], + CalledVia::Space, + ); + + let tag_name = TagName::Global("Err".into()); + + // ok branch + let ok = Tag { + variant_var: var_store.fresh(), + ext_var: var_store.fresh(), + name: tag_name.clone(), + arguments: vec![(var_store.fresh(), no_region(call_func))], + }; + + let pattern = Pattern::AppliedTag { + whole_var: result_var, + ext_var: var_store.fresh(), + tag_name, + arguments: vec![( + var_store.fresh(), + no_region(Pattern::Identifier(Symbol::ARG_5)), + )], + }; + + let branch = WhenBranch { + patterns: vec![no_region(pattern)], + value: no_region(ok), + guard: None, + }; + + branches.push(branch); + } + + { + // err branch + let tag_name = TagName::Global("Ok".into()); + + let err = Tag { + variant_var: var_store.fresh(), + ext_var: var_store.fresh(), + name: tag_name.clone(), + arguments: vec![(var_store.fresh(), no_region(Var(Symbol::ARG_4)))], + }; + + let pattern = Pattern::AppliedTag { + whole_var: result_var, + ext_var: var_store.fresh(), + tag_name, + arguments: vec![( + var_store.fresh(), + no_region(Pattern::Identifier(Symbol::ARG_4)), + )], + }; + + let branch = WhenBranch { + patterns: vec![no_region(pattern)], + value: no_region(err), + guard: None, + }; + + branches.push(branch); + } + + let body = When { + cond_var: result_var, + expr_var: ret_var, + region: Region::zero(), + loc_cond: Box::new(no_region(Var(Symbol::ARG_1))), + branches, + }; + + defn( + symbol, + vec![(result_var, Symbol::ARG_1), (func_var, Symbol::ARG_2)], + var_store, + body, + ret_var, + ) +} + +fn result_with_default(symbol: Symbol, var_store: &mut VarStore) -> Def { + let ret_var = var_store.fresh(); + let result_var = var_store.fresh(); + + let mut branches = vec![]; + + { + // ok branch + let tag_name = TagName::Global("Ok".into()); + + let pattern = Pattern::AppliedTag { + whole_var: result_var, + ext_var: var_store.fresh(), + tag_name, + arguments: vec![(ret_var, no_region(Pattern::Identifier(Symbol::ARG_3)))], + }; + + let branch = WhenBranch { + patterns: vec![no_region(pattern)], + value: no_region(Var(Symbol::ARG_3)), + guard: None, + }; + + branches.push(branch); + } + + { + // err branch + let tag_name = TagName::Global("Err".into()); + + let pattern = Pattern::AppliedTag { + whole_var: result_var, + ext_var: var_store.fresh(), + tag_name, + arguments: vec![(var_store.fresh(), no_region(Pattern::Underscore))], + }; + + let branch = WhenBranch { + patterns: vec![no_region(pattern)], + value: no_region(Var(Symbol::ARG_2)), + guard: None, + }; + + branches.push(branch); + } + + let body = When { + cond_var: result_var, + expr_var: ret_var, + region: Region::zero(), + loc_cond: Box::new(no_region(Var(Symbol::ARG_1))), + branches, + }; + + defn( + symbol, + vec![(result_var, Symbol::ARG_1), (ret_var, Symbol::ARG_2)], + var_store, + body, + ret_var, + ) +} + #[inline(always)] fn no_region(value: T) -> Located { Located { diff --git a/compiler/gen/src/llvm/bitcode.rs b/compiler/gen/src/llvm/bitcode.rs index bfd6557365..bd494641d0 100644 --- a/compiler/gen/src/llvm/bitcode.rs +++ b/compiler/gen/src/llvm/bitcode.rs @@ -1,21 +1,376 @@ -use inkwell::types::BasicTypeEnum; -use roc_module::low_level::LowLevel; +use crate::debug_info_init; +use crate::llvm::build::{set_name, Env, FAST_CALL_CONV}; +use crate::llvm::convert::basic_type_from_layout; +use crate::llvm::refcounting::{decrement_refcount_layout, increment_refcount_layout, Mode}; +use inkwell::attributes::{Attribute, AttributeLoc}; +/// Helpers for interacting with the zig that generates bitcode +use inkwell::types::{BasicType, BasicTypeEnum}; +use inkwell::values::{BasicValueEnum, CallSiteValue, FunctionValue, InstructionValue}; +use inkwell::AddressSpace; +use roc_module::symbol::Symbol; +use roc_mono::layout::{Layout, LayoutIds}; pub fn call_bitcode_fn<'a, 'ctx, 'env>( - op: LowLevel, env: &Env<'a, 'ctx, 'env>, args: &[BasicValueEnum<'ctx>], fn_name: &str, ) -> BasicValueEnum<'ctx> { + call_bitcode_fn_help(env, args, fn_name) + .try_as_basic_value() + .left() + .unwrap_or_else(|| { + panic!( + "LLVM error: Did not get return value from bitcode function {:?}", + fn_name + ) + }) +} + +pub fn call_void_bitcode_fn<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + args: &[BasicValueEnum<'ctx>], + fn_name: &str, +) -> InstructionValue<'ctx> { + call_bitcode_fn_help(env, args, fn_name) + .try_as_basic_value() + .right() + .unwrap_or_else(|| panic!("LLVM error: Tried to call void bitcode function, but got return value from bitcode function, {:?}", fn_name)) +} + +fn call_bitcode_fn_help<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + args: &[BasicValueEnum<'ctx>], + fn_name: &str, +) -> CallSiteValue<'ctx> { let fn_val = env - .module - .get_function(fn_name) - .unwrap_or_else(|| panic!("Unrecognized builtin function: {:?} - if you're working on the Roc compiler, do you need to rebuild the bitcode? See compiler/builtins/bitcode/README.md", fn_name)); + .module + .get_function(fn_name) + .unwrap_or_else(|| panic!("Unrecognized builtin function: {:?} - if you're working on the Roc compiler, do you need to rebuild the bitcode? See compiler/builtins/bitcode/README.md", fn_name)); + let call = env.builder.build_call(fn_val, args, "call_builtin"); call.set_call_convention(fn_val.get_call_conventions()); - - call.try_as_basic_value() - .left() - .unwrap_or_else(|| panic!("LLVM error: Invalid call for low-level op {:?}", op)) + call +} + +const ARGUMENT_SYMBOLS: [Symbol; 8] = [ + Symbol::ARG_1, + Symbol::ARG_2, + Symbol::ARG_3, + Symbol::ARG_4, + Symbol::ARG_5, + Symbol::ARG_6, + Symbol::ARG_7, + Symbol::ARG_8, +]; + +pub fn build_transform_caller<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + function_layout: &Layout<'a>, + argument_layouts: &[Layout<'a>], +) -> FunctionValue<'ctx> { + let symbol = Symbol::ZIG_FUNCTION_CALLER; + let fn_name = layout_ids + .get(symbol, &function_layout) + .to_symbol_string(symbol, &env.interns); + + match env.module.get_function(fn_name.as_str()) { + Some(function_value) => function_value, + None => build_transform_caller_help(env, function_layout, argument_layouts, &fn_name), + } +} + +fn build_transform_caller_help<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + function_layout: &Layout<'a>, + argument_layouts: &[Layout<'a>], + fn_name: &str, +) -> FunctionValue<'ctx> { + debug_assert!(argument_layouts.len() <= 7); + + let block = env.builder.get_insert_block().expect("to be in a function"); + let di_location = env.builder.get_current_debug_location().unwrap(); + + let arg_type = env.context.i8_type().ptr_type(AddressSpace::Generic); + + let function_value = crate::llvm::refcounting::build_header_help( + env, + &fn_name, + env.context.void_type().into(), + &(bumpalo::vec![ in env.arena; BasicTypeEnum::PointerType(arg_type); argument_layouts.len() + 2 ]), + ); + + let kind_id = Attribute::get_named_enum_kind_id("alwaysinline"); + debug_assert!(kind_id > 0); + let attr = env.context.create_enum_attribute(kind_id, 1); + function_value.add_attribute(AttributeLoc::Function, attr); + + let entry = env.context.append_basic_block(function_value, "entry"); + env.builder.position_at_end(entry); + + debug_info_init!(env, function_value); + + let mut it = function_value.get_param_iter(); + let closure_ptr = it.next().unwrap().into_pointer_value(); + set_name(closure_ptr.into(), Symbol::ARG_1.ident_string(&env.interns)); + + let arguments = + bumpalo::collections::Vec::from_iter_in(it.take(argument_layouts.len()), env.arena); + + for (argument, name) in arguments.iter().zip(ARGUMENT_SYMBOLS[1..].iter()) { + set_name(*argument, name.ident_string(&env.interns)); + } + + let closure_type = + basic_type_from_layout(env.arena, env.context, function_layout, env.ptr_bytes) + .ptr_type(AddressSpace::Generic); + + let mut arguments_cast = + bumpalo::collections::Vec::with_capacity_in(arguments.len(), env.arena); + + for (argument_ptr, layout) in arguments.iter().zip(argument_layouts) { + let basic_type = basic_type_from_layout(env.arena, env.context, layout, env.ptr_bytes) + .ptr_type(AddressSpace::Generic); + + let argument_cast = env + .builder + .build_bitcast(*argument_ptr, basic_type, "load_opaque") + .into_pointer_value(); + + let argument = env.builder.build_load(argument_cast, "load_opaque"); + + arguments_cast.push(argument); + } + + let closure_cast = env + .builder + .build_bitcast(closure_ptr, closure_type, "load_opaque") + .into_pointer_value(); + + let fpointer = env.builder.build_load(closure_cast, "load_opaque"); + + let call = match function_layout { + Layout::FunctionPointer(_, _) => env.builder.build_call( + fpointer.into_pointer_value(), + arguments_cast.as_slice(), + "tmp", + ), + Layout::Closure(_, _, _) | Layout::Struct(_) => { + let pair = fpointer.into_struct_value(); + + let fpointer = env + .builder + .build_extract_value(pair, 0, "get_fpointer") + .unwrap(); + + let closure_data = env + .builder + .build_extract_value(pair, 1, "get_closure_data") + .unwrap(); + + arguments_cast.push(closure_data); + env.builder.build_call( + fpointer.into_pointer_value(), + arguments_cast.as_slice(), + "tmp", + ) + } + _ => unreachable!("layout is not callable {:?}", function_layout), + }; + call.set_call_convention(FAST_CALL_CONV); + + let result = call + .try_as_basic_value() + .left() + .unwrap_or_else(|| panic!("LLVM error: Invalid call by pointer.")); + + let result_u8_ptr = function_value + .get_nth_param(argument_layouts.len() as u32 + 1) + .unwrap(); + let result_ptr = env + .builder + .build_bitcast( + result_u8_ptr, + result.get_type().ptr_type(AddressSpace::Generic), + "write_result", + ) + .into_pointer_value(); + + env.builder.build_store(result_ptr, result); + env.builder.build_return(None); + + env.builder.position_at_end(block); + env.builder + .set_current_debug_location(env.context, di_location); + + function_value +} + +pub fn build_inc_wrapper<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + layout: &Layout<'a>, +) -> FunctionValue<'ctx> { + build_rc_wrapper(env, layout_ids, layout, Mode::Inc(1)) +} + +pub fn build_dec_wrapper<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + layout: &Layout<'a>, +) -> FunctionValue<'ctx> { + build_rc_wrapper(env, layout_ids, layout, Mode::Dec) +} + +pub fn build_rc_wrapper<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + layout: &Layout<'a>, + rc_operation: Mode, +) -> FunctionValue<'ctx> { + let block = env.builder.get_insert_block().expect("to be in a function"); + let di_location = env.builder.get_current_debug_location().unwrap(); + + let symbol = Symbol::GENERIC_RC_REF; + let fn_name = layout_ids + .get(symbol, &layout) + .to_symbol_string(symbol, &env.interns); + + let fn_name = match rc_operation { + Mode::Inc(n) => format!("{}_inc_{}", fn_name, n), + Mode::Dec => format!("{}_dec", fn_name), + }; + + let function_value = match env.module.get_function(fn_name.as_str()) { + Some(function_value) => function_value, + None => { + let arg_type = env.context.i8_type().ptr_type(AddressSpace::Generic); + + let function_value = crate::llvm::refcounting::build_header_help( + env, + &fn_name, + env.context.void_type().into(), + &[arg_type.into()], + ); + + let kind_id = Attribute::get_named_enum_kind_id("alwaysinline"); + debug_assert!(kind_id > 0); + let attr = env.context.create_enum_attribute(kind_id, 1); + function_value.add_attribute(AttributeLoc::Function, attr); + + let entry = env.context.append_basic_block(function_value, "entry"); + env.builder.position_at_end(entry); + + debug_info_init!(env, function_value); + + let mut it = function_value.get_param_iter(); + let value_ptr = it.next().unwrap().into_pointer_value(); + + set_name(value_ptr.into(), Symbol::ARG_1.ident_string(&env.interns)); + + let value_type = basic_type_from_layout(env.arena, env.context, layout, env.ptr_bytes) + .ptr_type(AddressSpace::Generic); + + let value_cast = env + .builder + .build_bitcast(value_ptr, value_type, "load_opaque") + .into_pointer_value(); + + let value = env.builder.build_load(value_cast, "load_opaque"); + + match rc_operation { + Mode::Inc(n) => { + increment_refcount_layout(env, function_value, layout_ids, n, value, layout); + } + Mode::Dec => { + decrement_refcount_layout(env, function_value, layout_ids, value, layout); + } + } + + env.builder.build_return(None); + + function_value + } + }; + + env.builder.position_at_end(block); + env.builder + .set_current_debug_location(env.context, di_location); + + function_value +} + +pub fn build_eq_wrapper<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + layout: &Layout<'a>, +) -> FunctionValue<'ctx> { + let block = env.builder.get_insert_block().expect("to be in a function"); + let di_location = env.builder.get_current_debug_location().unwrap(); + + let symbol = Symbol::GENERIC_EQ_REF; + let fn_name = layout_ids + .get(symbol, &layout) + .to_symbol_string(symbol, &env.interns); + + let function_value = match env.module.get_function(fn_name.as_str()) { + Some(function_value) => function_value, + None => { + let arg_type = env.context.i8_type().ptr_type(AddressSpace::Generic); + + let function_value = crate::llvm::refcounting::build_header_help( + env, + &fn_name, + env.context.bool_type().into(), + &[arg_type.into(), arg_type.into()], + ); + + let kind_id = Attribute::get_named_enum_kind_id("alwaysinline"); + debug_assert!(kind_id > 0); + let attr = env.context.create_enum_attribute(kind_id, 1); + function_value.add_attribute(AttributeLoc::Function, attr); + + let entry = env.context.append_basic_block(function_value, "entry"); + env.builder.position_at_end(entry); + + debug_info_init!(env, function_value); + + let mut it = function_value.get_param_iter(); + let value_ptr1 = it.next().unwrap().into_pointer_value(); + let value_ptr2 = it.next().unwrap().into_pointer_value(); + + set_name(value_ptr1.into(), Symbol::ARG_1.ident_string(&env.interns)); + set_name(value_ptr2.into(), Symbol::ARG_2.ident_string(&env.interns)); + + let value_type = basic_type_from_layout(env.arena, env.context, layout, env.ptr_bytes) + .ptr_type(AddressSpace::Generic); + + let value_cast1 = env + .builder + .build_bitcast(value_ptr1, value_type, "load_opaque") + .into_pointer_value(); + + let value_cast2 = env + .builder + .build_bitcast(value_ptr2, value_type, "load_opaque") + .into_pointer_value(); + + let value1 = env.builder.build_load(value_cast1, "load_opaque"); + let value2 = env.builder.build_load(value_cast2, "load_opaque"); + + let result = + crate::llvm::compare::generic_eq(env, layout_ids, value1, value2, layout, layout); + + env.builder.build_return(Some(&result)); + + function_value + } + }; + + env.builder.position_at_end(block); + env.builder + .set_current_debug_location(env.context, di_location); + + function_value } diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index 6ba4296318..df7e65c4a0 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -1,14 +1,17 @@ +use crate::llvm::bitcode::call_bitcode_fn; use crate::llvm::build_dict::{ - dict_contains, dict_empty, dict_get, dict_insert, dict_keys, dict_len, dict_remove, dict_values, + dict_contains, dict_difference, dict_empty, dict_get, dict_insert, dict_intersection, + dict_keys, dict_len, dict_remove, dict_union, dict_values, dict_walk, set_from_list, }; use crate::llvm::build_hash::generic_hash; use crate::llvm::build_list::{ allocate_list, empty_list, empty_polymorphic_list, list_append, list_concat, list_contains, - list_get_unsafe, list_join, list_keep_if, list_len, list_map, list_prepend, list_repeat, - list_reverse, list_set, list_single, list_sum, list_walk, list_walk_backwards, + list_get_unsafe, list_join, list_keep_errs, list_keep_if, list_keep_oks, list_len, list_map, + list_map_with_index, list_prepend, list_repeat, list_reverse, list_set, list_single, list_sum, + list_walk, list_walk_backwards, }; use crate::llvm::build_str::{ - str_concat, str_count_graphemes, str_ends_with, str_from_int, str_join_with, + str_concat, str_count_graphemes, str_ends_with, str_from_float, str_from_int, str_join_with, str_number_of_bytes, str_split, str_starts_with, CHAR_LAYOUT, }; use crate::llvm::compare::{generic_eq, generic_neq}; @@ -35,8 +38,8 @@ use inkwell::passes::{PassManager, PassManagerBuilder}; use inkwell::types::{BasicTypeEnum, FunctionType, IntType, StructType}; use inkwell::values::BasicValueEnum::{self, *}; use inkwell::values::{ - BasicValue, CallSiteValue, FloatValue, FunctionValue, InstructionOpcode, InstructionValue, - IntValue, PointerValue, StructValue, + BasicValue, CallSiteValue, FloatValue, FunctionValue, InstructionOpcode, IntValue, + PointerValue, StructValue, }; use inkwell::OptimizationLevel; use inkwell::{AddressSpace, IntPredicate}; @@ -753,8 +756,8 @@ pub fn build_exp_call<'a, 'ctx, 'env>( } non_ptr => { panic!( - "Tried to call by pointer, but encountered a non-pointer: {:?}", - non_ptr + "Tried to call by pointer, but encountered a non-pointer: {:?} {:?} {:?}", + name, non_ptr, full_layout ); } }; @@ -1850,6 +1853,7 @@ fn invoke_roc_function<'a, 'ctx, 'env>( layout: Layout<'a>, function_value: Either, PointerValue<'ctx>>, arguments: &[Symbol], + closure_argument: Option>, pass: &'a roc_mono::ir::Stmt<'a>, fail: &'a roc_mono::ir::Stmt<'a>, ) -> BasicValueEnum<'ctx> { @@ -1860,6 +1864,7 @@ fn invoke_roc_function<'a, 'ctx, 'env>( for arg in arguments.iter() { arg_vals.push(load_symbol(scope, arg)); } + arg_vals.extend(closure_argument); let pass_block = context.append_basic_block(parent, "invoke_pass"); let fail_block = context.append_basic_block(parent, "invoke_fail"); @@ -2019,6 +2024,7 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>( layout.clone(), function_value.into(), call.arguments, + None, pass, fail, ) @@ -2026,28 +2032,57 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>( CallType::ByPointer { name, .. } => { let sub_expr = load_symbol(scope, &name); - let function_ptr = match sub_expr { - BasicValueEnum::PointerValue(ptr) => ptr, + match sub_expr { + BasicValueEnum::PointerValue(function_ptr) => { + // basic call by pointer + invoke_roc_function( + env, + layout_ids, + scope, + parent, + *symbol, + layout.clone(), + function_ptr.into(), + call.arguments, + None, + pass, + fail, + ) + } + BasicValueEnum::StructValue(ptr_and_data) => { + // this is a closure + let builder = env.builder; + + let function_ptr = builder + .build_extract_value(ptr_and_data, 0, "function_ptr") + .unwrap() + .into_pointer_value(); + + let closure_data = builder + .build_extract_value(ptr_and_data, 1, "closure_data") + .unwrap(); + + invoke_roc_function( + env, + layout_ids, + scope, + parent, + *symbol, + layout.clone(), + function_ptr.into(), + call.arguments, + Some(closure_data), + pass, + fail, + ) + } non_ptr => { panic!( "Tried to call by pointer, but encountered a non-pointer: {:?}", non_ptr ); } - }; - - invoke_roc_function( - env, - layout_ids, - scope, - parent, - *symbol, - layout.clone(), - function_ptr.into(), - call.arguments, - pass, - fail, - ) + } } CallType::Foreign { ref foreign_symbol, @@ -2203,9 +2238,14 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>( let (value, layout) = load_symbol_and_layout(scope, symbol); if layout.is_refcounted() { - let value_ptr = value.into_pointer_value(); - let refcount_ptr = PointerToRefcount::from_ptr_to_data(env, value_ptr); - refcount_ptr.decrement(env, layout); + if value.is_pointer_value() { + // BasicValueEnum::PointerValue(value_ptr) => { + let value_ptr = value.into_pointer_value(); + let refcount_ptr = PointerToRefcount::from_ptr_to_data(env, value_ptr); + refcount_ptr.decrement(env, layout); + } else { + eprint!("we're likely leaking memory; see issue #985 for details"); + } } build_exp_stmt(env, layout_ids, scope, parent, cont) @@ -3497,6 +3537,12 @@ fn run_low_level<'a, 'ctx, 'env>( str_from_int(env, scope, args[0]) } + StrFromFloat => { + // Str.fromFloat : Float * -> Str + debug_assert_eq!(args.len(), 1); + + str_from_float(env, scope, args[0]) + } StrSplit => { // Str.split : Str, Str -> List Str debug_assert_eq!(args.len(), 2); @@ -3582,18 +3628,29 @@ fn run_low_level<'a, 'ctx, 'env>( let (func, func_layout) = load_symbol_and_layout(scope, &args[1]); - let inplace = get_inplace_from_layout(layout); + match list_layout { + Layout::Builtin(Builtin::EmptyList) => empty_list(env), + Layout::Builtin(Builtin::List(_, element_layout)) => { + list_map(env, layout_ids, func, func_layout, list, element_layout) + } + _ => unreachable!("invalid list layout"), + } + } + ListMapWithIndex => { + // List.map : List before, (before -> after) -> List after + debug_assert_eq!(args.len(), 2); - list_map( - env, - layout_ids, - inplace, - parent, - func, - func_layout, - list, - list_layout, - ) + let (list, list_layout) = load_symbol_and_layout(scope, &args[0]); + + let (func, func_layout) = load_symbol_and_layout(scope, &args[1]); + + match list_layout { + Layout::Builtin(Builtin::EmptyList) => empty_list(env), + Layout::Builtin(Builtin::List(_, element_layout)) => { + list_map_with_index(env, layout_ids, func, func_layout, list, element_layout) + } + _ => unreachable!("invalid list layout"), + } } ListKeepIf => { // List.keepIf : List elem, (elem -> Bool) -> List elem @@ -3603,36 +3660,79 @@ fn run_low_level<'a, 'ctx, 'env>( let (func, func_layout) = load_symbol_and_layout(scope, &args[1]); - let inplace = get_inplace_from_layout(layout); + match list_layout { + Layout::Builtin(Builtin::EmptyList) => empty_list(env), + Layout::Builtin(Builtin::List(_, element_layout)) => { + list_keep_if(env, layout_ids, func, func_layout, list, element_layout) + } + _ => unreachable!("invalid list layout"), + } + } + ListKeepOks => { + // List.keepOks : List before, (before -> Result after *) -> List after + debug_assert_eq!(args.len(), 2); - list_keep_if( - env, - layout_ids, - inplace, - parent, - func, - func_layout, - list, - list_layout, - ) + let (list, list_layout) = load_symbol_and_layout(scope, &args[0]); + + let (func, func_layout) = load_symbol_and_layout(scope, &args[1]); + + match (list_layout, layout) { + (_, Layout::Builtin(Builtin::EmptyList)) + | (Layout::Builtin(Builtin::EmptyList), _) => empty_list(env), + ( + Layout::Builtin(Builtin::List(_, before_layout)), + Layout::Builtin(Builtin::List(_, after_layout)), + ) => list_keep_oks( + env, + layout_ids, + func, + func_layout, + list, + before_layout, + after_layout, + ), + (other1, other2) => { + unreachable!("invalid list layouts:\n{:?}\n{:?}", other1, other2) + } + } + } + ListKeepErrs => { + // List.keepErrs : List before, (before -> Result * after) -> List after + debug_assert_eq!(args.len(), 2); + + let (list, list_layout) = load_symbol_and_layout(scope, &args[0]); + + let (func, func_layout) = load_symbol_and_layout(scope, &args[1]); + + match (list_layout, layout) { + (_, Layout::Builtin(Builtin::EmptyList)) + | (Layout::Builtin(Builtin::EmptyList), _) => empty_list(env), + ( + Layout::Builtin(Builtin::List(_, before_layout)), + Layout::Builtin(Builtin::List(_, after_layout)), + ) => list_keep_errs( + env, + layout_ids, + func, + func_layout, + list, + before_layout, + after_layout, + ), + (other1, other2) => { + unreachable!("invalid list layouts:\n{:?}\n{:?}", other1, other2) + } + } } ListContains => { // List.contains : List elem, elem -> Bool debug_assert_eq!(args.len(), 2); - let (list, list_layout) = load_symbol_and_layout(scope, &args[0]); + let list = load_symbol(scope, &args[0]); let (elem, elem_layout) = load_symbol_and_layout(scope, &args[1]); - list_contains( - env, - layout_ids, - parent, - elem, - elem_layout, - list, - list_layout, - ) + list_contains(env, layout_ids, elem, elem_layout, list) } ListWalk => { debug_assert_eq!(args.len(), 3); @@ -3643,16 +3743,21 @@ fn run_low_level<'a, 'ctx, 'env>( let (default, default_layout) = load_symbol_and_layout(scope, &args[2]); - list_walk( - env, - parent, - list, - list_layout, - func, - func_layout, - default, - default_layout, - ) + match list_layout { + Layout::Builtin(Builtin::EmptyList) => default, + Layout::Builtin(Builtin::List(_, element_layout)) => list_walk( + env, + layout_ids, + parent, + list, + element_layout, + func, + func_layout, + default, + default_layout, + ), + _ => unreachable!("invalid list layout"), + } } ListWalkBackwards => { // List.walkBackwards : List elem, (elem -> accum -> accum), accum -> accum @@ -3664,16 +3769,21 @@ fn run_low_level<'a, 'ctx, 'env>( let (default, default_layout) = load_symbol_and_layout(scope, &args[2]); - list_walk_backwards( - env, - parent, - list, - list_layout, - func, - func_layout, - default, - default_layout, - ) + match list_layout { + Layout::Builtin(Builtin::EmptyList) => default, + Layout::Builtin(Builtin::List(_, element_layout)) => list_walk_backwards( + env, + layout_ids, + parent, + list, + element_layout, + func, + func_layout, + default, + default_layout, + ), + _ => unreachable!("invalid list layout"), + } } ListSum => { debug_assert_eq!(args.len(), 1); @@ -4047,7 +4157,7 @@ fn run_low_level<'a, 'ctx, 'env>( match dict_layout { Layout::Builtin(Builtin::EmptyDict) => { // no elements, so `key` is not in here - panic!("key type unknown") + empty_list(env) } Layout::Builtin(Builtin::Dict(key_layout, value_layout)) => { dict_keys(env, layout_ids, dict, key_layout, value_layout) @@ -4063,7 +4173,7 @@ fn run_low_level<'a, 'ctx, 'env>( match dict_layout { Layout::Builtin(Builtin::EmptyDict) => { // no elements, so `key` is not in here - panic!("key type unknown") + empty_list(env) } Layout::Builtin(Builtin::Dict(key_layout, value_layout)) => { dict_values(env, layout_ids, dict, key_layout, value_layout) @@ -4071,6 +4181,96 @@ fn run_low_level<'a, 'ctx, 'env>( _ => unreachable!("invalid dict layout"), } } + DictUnion => { + debug_assert_eq!(args.len(), 2); + + let (dict1, dict_layout) = load_symbol_and_layout(scope, &args[0]); + let (dict2, _) = load_symbol_and_layout(scope, &args[1]); + + match dict_layout { + Layout::Builtin(Builtin::EmptyDict) => { + // no elements, so `key` is not in here + panic!("key type unknown") + } + Layout::Builtin(Builtin::Dict(key_layout, value_layout)) => { + dict_union(env, layout_ids, dict1, dict2, key_layout, value_layout) + } + _ => unreachable!("invalid dict layout"), + } + } + DictDifference => { + debug_assert_eq!(args.len(), 2); + + let (dict1, dict_layout) = load_symbol_and_layout(scope, &args[0]); + let (dict2, _) = load_symbol_and_layout(scope, &args[1]); + + match dict_layout { + Layout::Builtin(Builtin::EmptyDict) => { + // no elements, so `key` is not in here + panic!("key type unknown") + } + Layout::Builtin(Builtin::Dict(key_layout, value_layout)) => { + dict_difference(env, layout_ids, dict1, dict2, key_layout, value_layout) + } + _ => unreachable!("invalid dict layout"), + } + } + DictIntersection => { + debug_assert_eq!(args.len(), 2); + + let (dict1, dict_layout) = load_symbol_and_layout(scope, &args[0]); + let (dict2, _) = load_symbol_and_layout(scope, &args[1]); + + match dict_layout { + Layout::Builtin(Builtin::EmptyDict) => { + // no elements, so `key` is not in here + panic!("key type unknown") + } + Layout::Builtin(Builtin::Dict(key_layout, value_layout)) => { + dict_intersection(env, layout_ids, dict1, dict2, key_layout, value_layout) + } + _ => unreachable!("invalid dict layout"), + } + } + DictWalk => { + debug_assert_eq!(args.len(), 3); + + let (dict, dict_layout) = load_symbol_and_layout(scope, &args[0]); + let (stepper, stepper_layout) = load_symbol_and_layout(scope, &args[1]); + let (accum, accum_layout) = load_symbol_and_layout(scope, &args[2]); + + match dict_layout { + Layout::Builtin(Builtin::EmptyDict) => { + // no elements, so `key` is not in here + panic!("key type unknown") + } + Layout::Builtin(Builtin::Dict(key_layout, value_layout)) => dict_walk( + env, + layout_ids, + dict, + stepper, + accum, + stepper_layout, + key_layout, + value_layout, + accum_layout, + ), + _ => unreachable!("invalid dict layout"), + } + } + SetFromList => { + debug_assert_eq!(args.len(), 1); + + let (list, list_layout) = load_symbol_and_layout(scope, &args[0]); + + match list_layout { + Layout::Builtin(Builtin::EmptyList) => dict_empty(env, scope), + Layout::Builtin(Builtin::List(_, key_layout)) => { + set_from_list(env, layout_ids, list, key_layout) + } + _ => unreachable!("invalid dict layout"), + } + } } } @@ -4393,49 +4593,6 @@ fn build_int_binop<'a, 'ctx, 'env>( } } -pub fn call_bitcode_fn<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - args: &[BasicValueEnum<'ctx>], - fn_name: &str, -) -> BasicValueEnum<'ctx> { - call_bitcode_fn_help(env, args, fn_name) - .try_as_basic_value() - .left() - .unwrap_or_else(|| { - panic!( - "LLVM error: Did not get return value from bitcode function {:?}", - fn_name - ) - }) -} - -pub fn call_void_bitcode_fn<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - args: &[BasicValueEnum<'ctx>], - fn_name: &str, -) -> InstructionValue<'ctx> { - call_bitcode_fn_help(env, args, fn_name) - .try_as_basic_value() - .right() - .unwrap_or_else(|| panic!("LLVM error: Tried to call void bitcode function, but got return value from bitcode function, {:?}", fn_name)) -} - -fn call_bitcode_fn_help<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - args: &[BasicValueEnum<'ctx>], - fn_name: &str, -) -> CallSiteValue<'ctx> { - let fn_val = env - .module - .get_function(fn_name) - .unwrap_or_else(|| panic!("Unrecognized builtin function: {:?} - if you're working on the Roc compiler, do you need to rebuild the bitcode? See compiler/builtins/bitcode/README.md", fn_name)); - - let call = env.builder.build_call(fn_val, args, "call_builtin"); - - call.set_call_convention(fn_val.get_call_conventions()); - call -} - pub fn build_num_binop<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, parent: FunctionValue<'ctx>, diff --git a/compiler/gen/src/llvm/build_dict.rs b/compiler/gen/src/llvm/build_dict.rs index 3b8cd5feed..8aa6f17887 100644 --- a/compiler/gen/src/llvm/build_dict.rs +++ b/compiler/gen/src/llvm/build_dict.rs @@ -1,10 +1,13 @@ use crate::debug_info_init; +use crate::llvm::bitcode::{ + build_dec_wrapper, build_eq_wrapper, build_inc_wrapper, build_transform_caller, + call_bitcode_fn, call_void_bitcode_fn, +}; use crate::llvm::build::{ - call_bitcode_fn, call_void_bitcode_fn, complex_bitcast, load_symbol, load_symbol_and_layout, - set_name, Env, Scope, + complex_bitcast, load_symbol, load_symbol_and_layout, set_name, Env, Scope, }; use crate::llvm::convert::{self, as_const_zero, basic_type_from_layout, collection}; -use crate::llvm::refcounting::{decrement_refcount_layout, increment_refcount_layout, Mode}; +use crate::llvm::refcounting::Mode; use inkwell::attributes::{Attribute, AttributeLoc}; use inkwell::types::BasicType; use inkwell::values::{BasicValueEnum, FunctionValue, StructValue}; @@ -128,8 +131,8 @@ pub fn dict_insert<'a, 'ctx, 'env>( let hash_fn = build_hash_wrapper(env, layout_ids, key_layout); let eq_fn = build_eq_wrapper(env, layout_ids, key_layout); - let dec_key_fn = build_rc_wrapper(env, layout_ids, key_layout, Mode::Dec); - let dec_value_fn = build_rc_wrapper(env, layout_ids, value_layout, Mode::Dec); + let dec_key_fn = build_dec_wrapper(env, layout_ids, key_layout); + let dec_value_fn = build_dec_wrapper(env, layout_ids, value_layout); call_void_bitcode_fn( env, @@ -198,8 +201,8 @@ pub fn dict_remove<'a, 'ctx, 'env>( let hash_fn = build_hash_wrapper(env, layout_ids, key_layout); let eq_fn = build_eq_wrapper(env, layout_ids, key_layout); - let dec_key_fn = build_rc_wrapper(env, layout_ids, key_layout, Mode::Dec); - let dec_value_fn = build_rc_wrapper(env, layout_ids, value_layout, Mode::Dec); + let dec_key_fn = build_dec_wrapper(env, layout_ids, key_layout); + let dec_value_fn = build_dec_wrapper(env, layout_ids, value_layout); call_void_bitcode_fn( env, @@ -315,7 +318,7 @@ pub fn dict_get<'a, 'ctx, 'env>( let hash_fn = build_hash_wrapper(env, layout_ids, key_layout); let eq_fn = build_eq_wrapper(env, layout_ids, key_layout); - let inc_value_fn = build_rc_wrapper(env, layout_ids, value_layout, Mode::Inc(1)); + let inc_value_fn = build_inc_wrapper(env, layout_ids, value_layout); // { flag: bool, value: *const u8 } let result = call_bitcode_fn( @@ -422,6 +425,7 @@ pub fn dict_elements_rc<'a, 'ctx, 'env>( let alignment = Alignment::from_key_value_layout(key_layout, value_layout, env.ptr_bytes); let alignment_iv = env.context.i8_type().const_int(alignment as u64, false); + use crate::llvm::bitcode::build_rc_wrapper; let inc_key_fn = build_rc_wrapper(env, layout_ids, key_layout, rc_operation); let inc_value_fn = build_rc_wrapper(env, layout_ids, value_layout, rc_operation); @@ -450,7 +454,7 @@ pub fn dict_keys<'a, 'ctx, 'env>( let builder = env.builder; let zig_dict_type = env.module.get_struct_type("dict.RocDict").unwrap(); - let zig_list_type = env.module.get_struct_type("dict.RocList").unwrap(); + let zig_list_type = env.module.get_struct_type("list.RocList").unwrap(); let dict_ptr = builder.build_alloca(zig_dict_type, "dict_ptr"); env.builder @@ -467,7 +471,7 @@ pub fn dict_keys<'a, 'ctx, 'env>( let alignment = Alignment::from_key_value_layout(key_layout, value_layout, env.ptr_bytes); let alignment_iv = env.context.i8_type().const_int(alignment as u64, false); - let inc_key_fn = build_rc_wrapper(env, layout_ids, key_layout, Mode::Inc(1)); + let inc_key_fn = build_inc_wrapper(env, layout_ids, key_layout); let list_ptr = builder.build_alloca(zig_list_type, "list_ptr"); @@ -496,6 +500,272 @@ pub fn dict_keys<'a, 'ctx, 'env>( env.builder.build_load(list_ptr, "load_keys_list") } +#[allow(clippy::too_many_arguments)] +pub fn dict_union<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + dict1: BasicValueEnum<'ctx>, + dict2: BasicValueEnum<'ctx>, + key_layout: &Layout<'a>, + value_layout: &Layout<'a>, +) -> BasicValueEnum<'ctx> { + let builder = env.builder; + + let zig_dict_type = env.module.get_struct_type("dict.RocDict").unwrap(); + + let dict1_ptr = builder.build_alloca(zig_dict_type, "dict_ptr"); + let dict2_ptr = builder.build_alloca(zig_dict_type, "dict_ptr"); + + env.builder.build_store( + dict1_ptr, + struct_to_zig_dict(env, dict1.into_struct_value()), + ); + + env.builder.build_store( + dict2_ptr, + struct_to_zig_dict(env, dict2.into_struct_value()), + ); + + let key_width = env + .ptr_int() + .const_int(key_layout.stack_size(env.ptr_bytes) as u64, false); + + let value_width = env + .ptr_int() + .const_int(value_layout.stack_size(env.ptr_bytes) as u64, false); + + let alignment = Alignment::from_key_value_layout(key_layout, value_layout, env.ptr_bytes); + let alignment_iv = env.context.i8_type().const_int(alignment as u64, false); + + let hash_fn = build_hash_wrapper(env, layout_ids, key_layout); + let eq_fn = build_eq_wrapper(env, layout_ids, key_layout); + + let inc_key_fn = build_inc_wrapper(env, layout_ids, key_layout); + let inc_value_fn = build_inc_wrapper(env, layout_ids, value_layout); + + let output_ptr = builder.build_alloca(zig_dict_type, "output_ptr"); + + call_void_bitcode_fn( + env, + &[ + dict1_ptr.into(), + dict2_ptr.into(), + alignment_iv.into(), + key_width.into(), + value_width.into(), + hash_fn.as_global_value().as_pointer_value().into(), + eq_fn.as_global_value().as_pointer_value().into(), + inc_key_fn.as_global_value().as_pointer_value().into(), + inc_value_fn.as_global_value().as_pointer_value().into(), + output_ptr.into(), + ], + &bitcode::DICT_UNION, + ); + + let output_ptr = env + .builder + .build_bitcast( + output_ptr, + convert::dict(env.context, env.ptr_bytes).ptr_type(AddressSpace::Generic), + "to_roc_dict", + ) + .into_pointer_value(); + + env.builder.build_load(output_ptr, "load_output_ptr") +} + +#[allow(clippy::too_many_arguments)] +pub fn dict_difference<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + dict1: BasicValueEnum<'ctx>, + dict2: BasicValueEnum<'ctx>, + key_layout: &Layout<'a>, + value_layout: &Layout<'a>, +) -> BasicValueEnum<'ctx> { + dict_intersect_or_difference( + env, + layout_ids, + dict1, + dict2, + key_layout, + value_layout, + &bitcode::DICT_DIFFERENCE, + ) +} + +#[allow(clippy::too_many_arguments)] +pub fn dict_intersection<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + dict1: BasicValueEnum<'ctx>, + dict2: BasicValueEnum<'ctx>, + key_layout: &Layout<'a>, + value_layout: &Layout<'a>, +) -> BasicValueEnum<'ctx> { + dict_intersect_or_difference( + env, + layout_ids, + dict1, + dict2, + key_layout, + value_layout, + &bitcode::DICT_INTERSECTION, + ) +} + +#[allow(clippy::too_many_arguments)] +fn dict_intersect_or_difference<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + dict1: BasicValueEnum<'ctx>, + dict2: BasicValueEnum<'ctx>, + key_layout: &Layout<'a>, + value_layout: &Layout<'a>, + op: &str, +) -> BasicValueEnum<'ctx> { + let builder = env.builder; + + let zig_dict_type = env.module.get_struct_type("dict.RocDict").unwrap(); + + let dict1_ptr = builder.build_alloca(zig_dict_type, "dict_ptr"); + let dict2_ptr = builder.build_alloca(zig_dict_type, "dict_ptr"); + + env.builder.build_store( + dict1_ptr, + struct_to_zig_dict(env, dict1.into_struct_value()), + ); + + env.builder.build_store( + dict2_ptr, + struct_to_zig_dict(env, dict2.into_struct_value()), + ); + + let key_width = env + .ptr_int() + .const_int(key_layout.stack_size(env.ptr_bytes) as u64, false); + + let value_width = env + .ptr_int() + .const_int(value_layout.stack_size(env.ptr_bytes) as u64, false); + + let alignment = Alignment::from_key_value_layout(key_layout, value_layout, env.ptr_bytes); + let alignment_iv = env.context.i8_type().const_int(alignment as u64, false); + + let hash_fn = build_hash_wrapper(env, layout_ids, key_layout); + let eq_fn = build_eq_wrapper(env, layout_ids, key_layout); + + let dec_key_fn = build_dec_wrapper(env, layout_ids, key_layout); + let dec_value_fn = build_dec_wrapper(env, layout_ids, value_layout); + + let output_ptr = builder.build_alloca(zig_dict_type, "output_ptr"); + + call_void_bitcode_fn( + env, + &[ + dict1_ptr.into(), + dict2_ptr.into(), + alignment_iv.into(), + key_width.into(), + value_width.into(), + hash_fn.as_global_value().as_pointer_value().into(), + eq_fn.as_global_value().as_pointer_value().into(), + dec_key_fn.as_global_value().as_pointer_value().into(), + dec_value_fn.as_global_value().as_pointer_value().into(), + output_ptr.into(), + ], + op, + ); + + let output_ptr = env + .builder + .build_bitcast( + output_ptr, + convert::dict(env.context, env.ptr_bytes).ptr_type(AddressSpace::Generic), + "to_roc_dict", + ) + .into_pointer_value(); + + env.builder.build_load(output_ptr, "load_output_ptr") +} + +#[allow(clippy::too_many_arguments)] +pub fn dict_walk<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + dict: BasicValueEnum<'ctx>, + stepper: BasicValueEnum<'ctx>, + accum: BasicValueEnum<'ctx>, + stepper_layout: &Layout<'a>, + key_layout: &Layout<'a>, + value_layout: &Layout<'a>, + accum_layout: &Layout<'a>, +) -> BasicValueEnum<'ctx> { + let builder = env.builder; + + let u8_ptr = env.context.i8_type().ptr_type(AddressSpace::Generic); + let zig_dict_type = env.module.get_struct_type("dict.RocDict").unwrap(); + + let dict_ptr = builder.build_alloca(zig_dict_type, "dict_ptr"); + env.builder + .build_store(dict_ptr, struct_to_zig_dict(env, dict.into_struct_value())); + + let stepper_ptr = builder.build_alloca(stepper.get_type(), "stepper_ptr"); + env.builder.build_store(stepper_ptr, stepper); + + let stepper_caller = build_transform_caller( + env, + layout_ids, + stepper_layout, + &[ + key_layout.clone(), + value_layout.clone(), + accum_layout.clone(), + ], + ) + .as_global_value() + .as_pointer_value(); + + let accum_bt = basic_type_from_layout(env.arena, env.context, accum_layout, env.ptr_bytes); + let accum_ptr = builder.build_alloca(accum_bt, "accum_ptr"); + env.builder.build_store(accum_ptr, accum); + + let key_width = env + .ptr_int() + .const_int(key_layout.stack_size(env.ptr_bytes) as u64, false); + + let value_width = env + .ptr_int() + .const_int(value_layout.stack_size(env.ptr_bytes) as u64, false); + + let accum_width = env + .ptr_int() + .const_int(accum_layout.stack_size(env.ptr_bytes) as u64, false); + + let alignment = Alignment::from_key_value_layout(key_layout, value_layout, env.ptr_bytes); + let alignment_iv = env.context.i8_type().const_int(alignment as u64, false); + + let output_ptr = builder.build_alloca(accum_bt, "output_ptr"); + + call_void_bitcode_fn( + env, + &[ + dict_ptr.into(), + env.builder.build_bitcast(stepper_ptr, u8_ptr, "to_opaque"), + stepper_caller.into(), + env.builder.build_bitcast(accum_ptr, u8_ptr, "to_opaque"), + alignment_iv.into(), + key_width.into(), + value_width.into(), + accum_width.into(), + env.builder.build_bitcast(output_ptr, u8_ptr, "to_opaque"), + ], + &bitcode::DICT_WALK, + ); + + env.builder.build_load(output_ptr, "load_output_ptr") +} + #[allow(clippy::too_many_arguments)] pub fn dict_values<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, @@ -507,7 +777,7 @@ pub fn dict_values<'a, 'ctx, 'env>( let builder = env.builder; let zig_dict_type = env.module.get_struct_type("dict.RocDict").unwrap(); - let zig_list_type = env.module.get_struct_type("dict.RocList").unwrap(); + let zig_list_type = env.module.get_struct_type("list.RocList").unwrap(); let dict_ptr = builder.build_alloca(zig_dict_type, "dict_ptr"); env.builder @@ -524,7 +794,7 @@ pub fn dict_values<'a, 'ctx, 'env>( let alignment = Alignment::from_key_value_layout(key_layout, value_layout, env.ptr_bytes); let alignment_iv = env.context.i8_type().const_int(alignment as u64, false); - let inc_value_fn = build_rc_wrapper(env, layout_ids, value_layout, Mode::Inc(1)); + let inc_value_fn = build_inc_wrapper(env, layout_ids, value_layout); let list_ptr = builder.build_alloca(zig_list_type, "list_ptr"); @@ -553,6 +823,68 @@ pub fn dict_values<'a, 'ctx, 'env>( env.builder.build_load(list_ptr, "load_keys_list") } +#[allow(clippy::too_many_arguments)] +pub fn set_from_list<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + list: BasicValueEnum<'ctx>, + key_layout: &Layout<'a>, +) -> BasicValueEnum<'ctx> { + let builder = env.builder; + + let zig_dict_type = env.module.get_struct_type("dict.RocDict").unwrap(); + + let list_alloca = builder.build_alloca(list.get_type(), "list_alloca"); + let list_ptr = env.builder.build_bitcast( + list_alloca, + env.context.i128_type().ptr_type(AddressSpace::Generic), + "to_zig_list", + ); + + env.builder.build_store(list_alloca, list); + + let key_width = env + .ptr_int() + .const_int(key_layout.stack_size(env.ptr_bytes) as u64, false); + + let value_width = env.ptr_int().const_zero(); + + let result_alloca = + builder.build_alloca(convert::dict(env.context, env.ptr_bytes), "result_alloca"); + let result_ptr = builder.build_bitcast( + result_alloca, + zig_dict_type.ptr_type(AddressSpace::Generic), + "to_zig_dict", + ); + + let alignment = + Alignment::from_key_value_layout(key_layout, &Layout::Struct(&[]), env.ptr_bytes); + let alignment_iv = env.context.i8_type().const_int(alignment as u64, false); + + let hash_fn = build_hash_wrapper(env, layout_ids, key_layout); + let eq_fn = build_eq_wrapper(env, layout_ids, key_layout); + + let dec_key_fn = build_dec_wrapper(env, layout_ids, key_layout); + + call_void_bitcode_fn( + env, + &[ + env.builder + .build_load(list_ptr.into_pointer_value(), "as_i128"), + alignment_iv.into(), + key_width.into(), + value_width.into(), + hash_fn.as_global_value().as_pointer_value().into(), + eq_fn.as_global_value().as_pointer_value().into(), + dec_key_fn.as_global_value().as_pointer_value().into(), + result_ptr, + ], + &bitcode::SET_FROM_LIST, + ); + + env.builder.build_load(result_alloca, "load_result") +} + fn build_hash_wrapper<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, @@ -622,158 +954,6 @@ fn build_hash_wrapper<'a, 'ctx, 'env>( function_value } -fn build_eq_wrapper<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - layout_ids: &mut LayoutIds<'a>, - layout: &Layout<'a>, -) -> FunctionValue<'ctx> { - let block = env.builder.get_insert_block().expect("to be in a function"); - let di_location = env.builder.get_current_debug_location().unwrap(); - - let symbol = Symbol::GENERIC_EQ_REF; - let fn_name = layout_ids - .get(symbol, &layout) - .to_symbol_string(symbol, &env.interns); - - let function_value = match env.module.get_function(fn_name.as_str()) { - Some(function_value) => function_value, - None => { - let arg_type = env.context.i8_type().ptr_type(AddressSpace::Generic); - - let function_value = crate::llvm::refcounting::build_header_help( - env, - &fn_name, - env.context.bool_type().into(), - &[arg_type.into(), arg_type.into()], - ); - - let kind_id = Attribute::get_named_enum_kind_id("alwaysinline"); - debug_assert!(kind_id > 0); - let attr = env.context.create_enum_attribute(kind_id, 1); - function_value.add_attribute(AttributeLoc::Function, attr); - - let entry = env.context.append_basic_block(function_value, "entry"); - env.builder.position_at_end(entry); - - debug_info_init!(env, function_value); - - let mut it = function_value.get_param_iter(); - let value_ptr1 = it.next().unwrap().into_pointer_value(); - let value_ptr2 = it.next().unwrap().into_pointer_value(); - - set_name(value_ptr1.into(), Symbol::ARG_1.ident_string(&env.interns)); - set_name(value_ptr2.into(), Symbol::ARG_2.ident_string(&env.interns)); - - let value_type = basic_type_from_layout(env.arena, env.context, layout, env.ptr_bytes) - .ptr_type(AddressSpace::Generic); - - let value_cast1 = env - .builder - .build_bitcast(value_ptr1, value_type, "load_opaque") - .into_pointer_value(); - - let value_cast2 = env - .builder - .build_bitcast(value_ptr2, value_type, "load_opaque") - .into_pointer_value(); - - let value1 = env.builder.build_load(value_cast1, "load_opaque"); - let value2 = env.builder.build_load(value_cast2, "load_opaque"); - - let result = - crate::llvm::compare::generic_eq(env, layout_ids, value1, value2, layout, layout); - - env.builder.build_return(Some(&result)); - - function_value - } - }; - - env.builder.position_at_end(block); - env.builder - .set_current_debug_location(env.context, di_location); - - function_value -} - -fn build_rc_wrapper<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - layout_ids: &mut LayoutIds<'a>, - layout: &Layout<'a>, - rc_operation: Mode, -) -> FunctionValue<'ctx> { - let block = env.builder.get_insert_block().expect("to be in a function"); - let di_location = env.builder.get_current_debug_location().unwrap(); - - let symbol = Symbol::GENERIC_RC_REF; - let fn_name = layout_ids - .get(symbol, &layout) - .to_symbol_string(symbol, &env.interns); - - let fn_name = match rc_operation { - Mode::Inc(n) => format!("{}_inc_{}", fn_name, n), - Mode::Dec => format!("{}_dec", fn_name), - }; - - let function_value = match env.module.get_function(fn_name.as_str()) { - Some(function_value) => function_value, - None => { - let arg_type = env.context.i8_type().ptr_type(AddressSpace::Generic); - - let function_value = crate::llvm::refcounting::build_header_help( - env, - &fn_name, - env.context.void_type().into(), - &[arg_type.into()], - ); - - let kind_id = Attribute::get_named_enum_kind_id("alwaysinline"); - debug_assert!(kind_id > 0); - let attr = env.context.create_enum_attribute(kind_id, 1); - function_value.add_attribute(AttributeLoc::Function, attr); - - let entry = env.context.append_basic_block(function_value, "entry"); - env.builder.position_at_end(entry); - - debug_info_init!(env, function_value); - - let mut it = function_value.get_param_iter(); - let value_ptr = it.next().unwrap().into_pointer_value(); - - set_name(value_ptr.into(), Symbol::ARG_1.ident_string(&env.interns)); - - let value_type = basic_type_from_layout(env.arena, env.context, layout, env.ptr_bytes) - .ptr_type(AddressSpace::Generic); - - let value_cast = env - .builder - .build_bitcast(value_ptr, value_type, "load_opaque") - .into_pointer_value(); - - let value = env.builder.build_load(value_cast, "load_opaque"); - - match rc_operation { - Mode::Inc(n) => { - increment_refcount_layout(env, function_value, layout_ids, n, value, layout); - } - Mode::Dec => { - decrement_refcount_layout(env, function_value, layout_ids, value, layout); - } - } - - env.builder.build_return(None); - - function_value - } - }; - - env.builder.position_at_end(block); - env.builder - .set_current_debug_location(env.context, di_location); - - function_value -} - fn dict_symbol_to_zig_dict<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, scope: &Scope<'a, 'ctx>, diff --git a/compiler/gen/src/llvm/build_hash.rs b/compiler/gen/src/llvm/build_hash.rs index a5445e1a39..0da80d7fc6 100644 --- a/compiler/gen/src/llvm/build_hash.rs +++ b/compiler/gen/src/llvm/build_hash.rs @@ -1,8 +1,7 @@ use crate::debug_info_init; +use crate::llvm::bitcode::call_bitcode_fn; use crate::llvm::build::Env; -use crate::llvm::build::{ - call_bitcode_fn, cast_block_of_memory_to_tag, complex_bitcast, set_name, FAST_CALL_CONV, -}; +use crate::llvm::build::{cast_block_of_memory_to_tag, complex_bitcast, set_name, FAST_CALL_CONV}; use crate::llvm::build_str; use crate::llvm::convert::basic_type_from_layout; use bumpalo::collections::Vec; diff --git a/compiler/gen/src/llvm/build_list.rs b/compiler/gen/src/llvm/build_list.rs index 48e7e6797f..c2f3639f0d 100644 --- a/compiler/gen/src/llvm/build_list.rs +++ b/compiler/gen/src/llvm/build_list.rs @@ -1,17 +1,21 @@ -use crate::llvm::build::{ - allocate_with_refcount_help, build_num_binop, cast_basic_basic, Env, InPlace, +#![allow(clippy::too_many_arguments)] +use crate::llvm::bitcode::{ + build_eq_wrapper, build_transform_caller, call_bitcode_fn, call_void_bitcode_fn, +}; +use crate::llvm::build::{ + allocate_with_refcount_help, build_num_binop, cast_basic_basic, complex_bitcast, Env, InPlace, }; -use crate::llvm::compare::generic_eq; use crate::llvm::convert::{basic_type_from_layout, collection, get_ptr_type}; use crate::llvm::refcounting::{ - decrement_refcount_layout, increment_refcount_layout, refcount_is_one_comparison, - PointerToRefcount, + increment_refcount_layout, refcount_is_one_comparison, PointerToRefcount, }; use inkwell::builder::Builder; use inkwell::context::Context; +use inkwell::types::BasicType; use inkwell::types::{BasicTypeEnum, PointerType}; use inkwell::values::{BasicValueEnum, FunctionValue, IntValue, PointerValue, StructValue}; use inkwell::{AddressSpace, IntPredicate}; +use roc_builtins::bitcode; use roc_mono::layout::{Builtin, Layout, LayoutIds, MemoryMode}; /// List.single : a -> List a @@ -461,8 +465,6 @@ pub fn list_reverse<'a, 'ctx, 'env>( list: BasicValueEnum<'ctx>, list_layout: &Layout<'a>, ) -> BasicValueEnum<'ctx> { - use inkwell::types::BasicType; - let builder = env.builder; let ctx = env.context; @@ -845,674 +847,423 @@ pub fn list_sum<'a, 'ctx, 'env>( } /// List.walk : List elem, (elem -> accum -> accum), accum -> accum -#[allow(clippy::too_many_arguments)] pub fn list_walk<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, parent: FunctionValue<'ctx>, list: BasicValueEnum<'ctx>, - list_layout: &Layout<'a>, + element_layout: &Layout<'a>, func: BasicValueEnum<'ctx>, func_layout: &Layout<'a>, default: BasicValueEnum<'ctx>, default_layout: &Layout<'a>, ) -> BasicValueEnum<'ctx> { - let ctx = env.context; - let builder = env.builder; - - let list_wrapper = list.into_struct_value(); - let len = list_len(env.builder, list_wrapper); - - let accum_type = basic_type_from_layout(env.arena, ctx, default_layout, env.ptr_bytes); - let accum_alloca = builder.build_alloca(accum_type, "alloca_walk_right_accum"); - builder.build_store(accum_alloca, default); - - let then_block = ctx.append_basic_block(parent, "then"); - let cont_block = ctx.append_basic_block(parent, "branchcont"); - - let condition = builder.build_int_compare( - IntPredicate::UGT, - len, - ctx.i64_type().const_zero(), - "list_non_empty", - ); - - builder.build_conditional_branch(condition, then_block, cont_block); - - builder.position_at_end(then_block); - - match (func, func_layout) { - (BasicValueEnum::PointerValue(func_ptr), Layout::FunctionPointer(_, _)) => { - let elem_layout = match list_layout { - Layout::Builtin(Builtin::List(_, layout)) => layout, - _ => unreachable!("can only fold over a list"), - }; - - let elem_type = basic_type_from_layout(env.arena, ctx, elem_layout, env.ptr_bytes); - let elem_ptr_type = get_ptr_type(&elem_type, AddressSpace::Generic); - - let list_ptr = load_list_ptr(builder, list_wrapper, elem_ptr_type); - - let walk_right_loop = |_, elem: BasicValueEnum<'ctx>| { - // load current accumulator - let current = builder.build_load(accum_alloca, "retrieve_accum"); - - let call_site_value = - builder.build_call(func_ptr, &[elem, current], "#walk_right_func"); - - // set the calling convention explicitly for this call - call_site_value.set_call_convention(crate::llvm::build::FAST_CALL_CONV); - - let new_current = call_site_value - .try_as_basic_value() - .left() - .unwrap_or_else(|| panic!("LLVM error: Invalid call by pointer.")); - - builder.build_store(accum_alloca, new_current); - }; - - incrementing_elem_loop( - builder, - ctx, - parent, - list_ptr, - len, - "#index", - walk_right_loop, - ); - } - - _ => { - unreachable!( - "Invalid function basic value enum or layout for List.keepIf : {:?}", - (func, func_layout) - ); - } - } - - builder.build_unconditional_branch(cont_block); - - builder.position_at_end(cont_block); - - builder.build_load(accum_alloca, "load_final_acum") + list_walk_generic( + env, + layout_ids, + parent, + list, + element_layout, + func, + func_layout, + default, + default_layout, + &bitcode::LIST_WALK, + ) } /// List.walkBackwards : List elem, (elem -> accum -> accum), accum -> accum -#[allow(clippy::too_many_arguments)] pub fn list_walk_backwards<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, parent: FunctionValue<'ctx>, list: BasicValueEnum<'ctx>, - list_layout: &Layout<'a>, + element_layout: &Layout<'a>, func: BasicValueEnum<'ctx>, func_layout: &Layout<'a>, default: BasicValueEnum<'ctx>, default_layout: &Layout<'a>, ) -> BasicValueEnum<'ctx> { - let ctx = env.context; + list_walk_generic( + env, + layout_ids, + parent, + list, + element_layout, + func, + func_layout, + default, + default_layout, + &bitcode::LIST_WALK_BACKWARDS, + ) +} + +fn list_walk_generic<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + _parent: FunctionValue<'ctx>, + list: BasicValueEnum<'ctx>, + element_layout: &Layout<'a>, + func: BasicValueEnum<'ctx>, + func_layout: &Layout<'a>, + default: BasicValueEnum<'ctx>, + default_layout: &Layout<'a>, + zig_function: &str, +) -> BasicValueEnum<'ctx> { let builder = env.builder; - let list_wrapper = list.into_struct_value(); - let len = list_len(env.builder, list_wrapper); + let u8_ptr = env.context.i8_type().ptr_type(AddressSpace::Generic); - let accum_type = basic_type_from_layout(env.arena, ctx, default_layout, env.ptr_bytes); - let accum_alloca = builder.build_alloca(accum_type, "alloca_walk_right_accum"); - builder.build_store(accum_alloca, default); + let list_i128 = complex_bitcast(env.builder, list, env.context.i128_type().into(), "to_i128"); - let then_block = ctx.append_basic_block(parent, "then"); - let cont_block = ctx.append_basic_block(parent, "branchcont"); + let transform_ptr = builder.build_alloca(func.get_type(), "transform_ptr"); + env.builder.build_store(transform_ptr, func); - let condition = builder.build_int_compare( - IntPredicate::UGT, - len, - ctx.i64_type().const_zero(), - "list_non_empty", + let default_ptr = builder.build_alloca(default.get_type(), "default_ptr"); + env.builder.build_store(default_ptr, default); + + let stepper_caller = build_transform_caller( + env, + layout_ids, + func_layout, + &[element_layout.clone(), default_layout.clone()], + ) + .as_global_value() + .as_pointer_value(); + + let element_width = env + .ptr_int() + .const_int(element_layout.stack_size(env.ptr_bytes) as u64, false); + + let default_width = env + .ptr_int() + .const_int(default_layout.stack_size(env.ptr_bytes) as u64, false); + + let alignment = element_layout.alignment_bytes(env.ptr_bytes); + let alignment_iv = env.ptr_int().const_int(alignment as u64, false); + + let result_ptr = env.builder.build_alloca(default.get_type(), "result"); + + call_void_bitcode_fn( + env, + &[ + list_i128, + env.builder + .build_bitcast(transform_ptr, u8_ptr, "to_opaque"), + stepper_caller.into(), + env.builder.build_bitcast(default_ptr, u8_ptr, "to_u8_ptr"), + alignment_iv.into(), + element_width.into(), + default_width.into(), + env.builder.build_bitcast(result_ptr, u8_ptr, "to_opaque"), + ], + zig_function, ); - builder.build_conditional_branch(condition, then_block, cont_block); - - builder.position_at_end(then_block); - - match (func, func_layout) { - (BasicValueEnum::PointerValue(func_ptr), Layout::FunctionPointer(_, _)) => { - let elem_layout = match list_layout { - Layout::Builtin(Builtin::List(_, layout)) => layout, - _ => unreachable!("can only fold over a list"), - }; - - let elem_type = basic_type_from_layout(env.arena, ctx, elem_layout, env.ptr_bytes); - let elem_ptr_type = get_ptr_type(&elem_type, AddressSpace::Generic); - - let list_ptr = load_list_ptr(builder, list_wrapper, elem_ptr_type); - - let walk_right_loop = |_, elem: BasicValueEnum<'ctx>| { - // load current accumulator - let current = builder.build_load(accum_alloca, "retrieve_accum"); - - let call_site_value = - builder.build_call(func_ptr, &[elem, current], "#walk_right_func"); - - // set the calling convention explicitly for this call - call_site_value.set_call_convention(crate::llvm::build::FAST_CALL_CONV); - - let new_current = call_site_value - .try_as_basic_value() - .left() - .unwrap_or_else(|| panic!("LLVM error: Invalid call by pointer.")); - - builder.build_store(accum_alloca, new_current); - }; - - decrementing_elem_loop( - builder, - ctx, - parent, - list_ptr, - len, - "#index", - walk_right_loop, - ); - } - - _ => { - unreachable!( - "Invalid function basic value enum or layout for List.keepIf : {:?}", - (func, func_layout) - ); - } - } - - builder.build_unconditional_branch(cont_block); - - builder.position_at_end(cont_block); - - builder.build_load(accum_alloca, "load_final_acum") + env.builder.build_load(result_ptr, "load_result") } /// List.contains : List elem, elem -> Bool pub fn list_contains<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, - parent: FunctionValue<'ctx>, - elem: BasicValueEnum<'ctx>, - elem_layout: &Layout<'a>, + element: BasicValueEnum<'ctx>, + element_layout: &Layout<'a>, list: BasicValueEnum<'ctx>, - list_layout: &Layout<'a>, ) -> BasicValueEnum<'ctx> { - use inkwell::types::BasicType; - let builder = env.builder; - let wrapper_struct = list.into_struct_value(); - let list_elem_layout = match &list_layout { - // this pointer will never actually be dereferenced - Layout::Builtin(Builtin::EmptyList) => &Layout::Builtin(Builtin::Int64), - Layout::Builtin(Builtin::List(_, element_layout)) => element_layout, - _ => unreachable!("Invalid layout {:?} in List.contains", list_layout), - }; + let u8_ptr = env.context.i8_type().ptr_type(AddressSpace::Generic); - let list_elem_type = - basic_type_from_layout(env.arena, env.context, list_elem_layout, env.ptr_bytes); + let list_i128 = complex_bitcast(env.builder, list, env.context.i128_type().into(), "to_i128"); - let list_ptr = load_list_ptr( - builder, - wrapper_struct, - list_elem_type.ptr_type(AddressSpace::Generic), - ); + let key_ptr = builder.build_alloca(element.get_type(), "key_ptr"); + env.builder.build_store(key_ptr, element); - let length = list_len(builder, list.into_struct_value()); + let element_width = env + .ptr_int() + .const_int(element_layout.stack_size(env.ptr_bytes) as u64, false); - list_contains_help( + let eq_fn = build_eq_wrapper(env, layout_ids, element_layout); + + call_bitcode_fn( env, - layout_ids, - parent, - length, - list_ptr, - list_elem_layout, - elem, - elem_layout, + &[ + list_i128, + env.builder.build_bitcast(key_ptr, u8_ptr, "to_u8_ptr"), + element_width.into(), + eq_fn.as_global_value().as_pointer_value().into(), + ], + bitcode::LIST_CONTAINS, ) } -#[allow(clippy::too_many_arguments)] -pub fn list_contains_help<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - layout_ids: &mut LayoutIds<'a>, - parent: FunctionValue<'ctx>, - length: IntValue<'ctx>, - source_ptr: PointerValue<'ctx>, - list_elem_layout: &Layout<'a>, - elem: BasicValueEnum<'ctx>, - elem_layout: &Layout<'a>, -) -> BasicValueEnum<'ctx> { - let builder = env.builder; - let ctx = env.context; - - let bool_alloca = builder.build_alloca(ctx.bool_type(), "bool_alloca"); - let index_alloca = builder.build_alloca(ctx.i64_type(), "index_alloca"); - let next_free_index_alloca = builder.build_alloca(ctx.i64_type(), "next_free_index_alloca"); - - builder.build_store(bool_alloca, ctx.bool_type().const_zero()); - builder.build_store(index_alloca, ctx.i64_type().const_zero()); - builder.build_store(next_free_index_alloca, ctx.i64_type().const_zero()); - - let condition_bb = ctx.append_basic_block(parent, "condition"); - builder.build_unconditional_branch(condition_bb); - builder.position_at_end(condition_bb); - - let index = builder.build_load(index_alloca, "index").into_int_value(); - - let condition = builder.build_int_compare(IntPredicate::SGT, length, index, "loopcond"); - - let body_bb = ctx.append_basic_block(parent, "body"); - let cont_bb = ctx.append_basic_block(parent, "cont"); - builder.build_conditional_branch(condition, body_bb, cont_bb); - - // loop body - builder.position_at_end(body_bb); - - let current_elem_ptr = unsafe { builder.build_in_bounds_gep(source_ptr, &[index], "elem_ptr") }; - - let current_elem = builder.build_load(current_elem_ptr, "load_elem"); - - let has_found = generic_eq( - env, - layout_ids, - current_elem, - elem, - list_elem_layout, - elem_layout, - ); - - builder.build_store(bool_alloca, has_found.into_int_value()); - - let one = ctx.i64_type().const_int(1, false); - - let next_free_index = builder - .build_load(next_free_index_alloca, "load_next_free") - .into_int_value(); - - builder.build_store( - next_free_index_alloca, - builder.build_int_add(next_free_index, one, "incremented_next_free_index"), - ); - - builder.build_store( - index_alloca, - builder.build_int_add(index, one, "incremented_index"), - ); - - builder.build_conditional_branch(has_found.into_int_value(), cont_bb, condition_bb); - - // continuation - builder.position_at_end(cont_bb); - - builder.build_load(bool_alloca, "answer") -} - /// List.keepIf : List elem, (elem -> Bool) -> List elem -#[allow(clippy::too_many_arguments)] pub fn list_keep_if<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, - output_inplace: InPlace, - parent: FunctionValue<'ctx>, - func: BasicValueEnum<'ctx>, - func_layout: &Layout<'a>, + transform: BasicValueEnum<'ctx>, + transform_layout: &Layout<'a>, list: BasicValueEnum<'ctx>, - list_layout: &Layout<'a>, + element_layout: &Layout<'a>, ) -> BasicValueEnum<'ctx> { - use inkwell::types::BasicType; - let builder = env.builder; - let ctx = env.context; - let wrapper_struct = list.into_struct_value(); - let (input_inplace, element_layout) = match list_layout.clone() { - Layout::Builtin(Builtin::EmptyList) => ( - InPlace::InPlace, - // this pointer will never actually be dereferenced - Layout::Builtin(Builtin::Int64), - ), - Layout::Builtin(Builtin::List(memory_mode, elem_layout)) => ( - match memory_mode { - MemoryMode::Unique => InPlace::InPlace, - MemoryMode::Refcounted => InPlace::Clone, - }, - elem_layout.clone(), - ), + let u8_ptr = env.context.i8_type().ptr_type(AddressSpace::Generic); - _ => unreachable!("Invalid layout {:?} in List.keepIf", list_layout), + let list_i128 = complex_bitcast(env.builder, list, env.context.i128_type().into(), "to_i128"); + + let transform_ptr = builder.build_alloca(transform.get_type(), "transform_ptr"); + env.builder.build_store(transform_ptr, transform); + + let stepper_caller = + build_transform_caller(env, layout_ids, transform_layout, &[element_layout.clone()]) + .as_global_value() + .as_pointer_value(); + + let element_width = env + .ptr_int() + .const_int(element_layout.stack_size(env.ptr_bytes) as u64, false); + + let alignment = element_layout.alignment_bytes(env.ptr_bytes); + let alignment_iv = env.ptr_int().const_int(alignment as u64, false); + + let output = call_bitcode_fn( + env, + &[ + list_i128, + env.builder + .build_bitcast(transform_ptr, u8_ptr, "to_opaque"), + stepper_caller.into(), + alignment_iv.into(), + element_width.into(), + ], + &bitcode::LIST_KEEP_IF, + ); + + complex_bitcast( + env.builder, + output, + collection(env.context, env.ptr_bytes).into(), + "from_i128", + ) +} + +/// List.keepOks : List before, (before -> Result after *) -> List after +pub fn list_keep_oks<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + transform: BasicValueEnum<'ctx>, + transform_layout: &Layout<'a>, + list: BasicValueEnum<'ctx>, + before_layout: &Layout<'a>, + after_layout: &Layout<'a>, +) -> BasicValueEnum<'ctx> { + list_keep_result( + env, + layout_ids, + transform, + transform_layout, + list, + before_layout, + after_layout, + bitcode::LIST_KEEP_OKS, + ) +} + +/// List.keepErrs : List before, (before -> Result * after) -> List after +pub fn list_keep_errs<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + transform: BasicValueEnum<'ctx>, + transform_layout: &Layout<'a>, + list: BasicValueEnum<'ctx>, + before_layout: &Layout<'a>, + after_layout: &Layout<'a>, +) -> BasicValueEnum<'ctx> { + list_keep_result( + env, + layout_ids, + transform, + transform_layout, + list, + before_layout, + after_layout, + bitcode::LIST_KEEP_ERRS, + ) +} + +pub fn list_keep_result<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + transform: BasicValueEnum<'ctx>, + transform_layout: &Layout<'a>, + list: BasicValueEnum<'ctx>, + before_layout: &Layout<'a>, + after_layout: &Layout<'a>, + op: &str, +) -> BasicValueEnum<'ctx> { + let builder = env.builder; + + let u8_ptr = env.context.i8_type().ptr_type(AddressSpace::Generic); + + let result_layout = match transform_layout { + Layout::FunctionPointer(_, ret) => ret, + Layout::Closure(_, _, ret) => ret, + _ => unreachable!("not a callable layout"), }; - let list_type = basic_type_from_layout(env.arena, env.context, &list_layout, env.ptr_bytes); - let elem_type = basic_type_from_layout(env.arena, env.context, &element_layout, env.ptr_bytes); - let ptr_type = elem_type.ptr_type(AddressSpace::Generic); + let list_i128 = complex_bitcast(env.builder, list, env.context.i128_type().into(), "to_i128"); - let list_ptr = load_list_ptr(builder, wrapper_struct, ptr_type); - let length = list_len(builder, list.into_struct_value()); + let transform_ptr = builder.build_alloca(transform.get_type(), "transform_ptr"); + env.builder.build_store(transform_ptr, transform); - let zero = ctx.i64_type().const_zero(); + let stepper_caller = + build_transform_caller(env, layout_ids, transform_layout, &[before_layout.clone()]) + .as_global_value() + .as_pointer_value(); - match input_inplace { - InPlace::InPlace => { - let new_length = list_keep_if_help( - env, - input_inplace, - parent, - length, - list_ptr, - list_ptr, - func, - func_layout, - ); + let before_width = env + .ptr_int() + .const_int(before_layout.stack_size(env.ptr_bytes) as u64, false); - store_list(env, list_ptr, new_length) - } - InPlace::Clone => { - let len_0_block = ctx.append_basic_block(parent, "len_0_block"); - let len_n_block = ctx.append_basic_block(parent, "len_n_block"); - let cont_block = ctx.append_basic_block(parent, "cont_block"); + let after_width = env + .ptr_int() + .const_int(after_layout.stack_size(env.ptr_bytes) as u64, false); - let result = builder.build_alloca(list_type, "result"); + let result_width = env + .ptr_int() + .const_int(result_layout.stack_size(env.ptr_bytes) as u64, false); - builder.build_switch(length, len_n_block, &[(zero, len_0_block)]); + let alignment = before_layout.alignment_bytes(env.ptr_bytes); + let alignment_iv = env.ptr_int().const_int(alignment as u64, false); - // build block for length 0 - { - builder.position_at_end(len_0_block); + let output = call_bitcode_fn( + env, + &[ + list_i128, + env.builder + .build_bitcast(transform_ptr, u8_ptr, "to_opaque"), + stepper_caller.into(), + alignment_iv.into(), + before_width.into(), + result_width.into(), + after_width.into(), + ], + op, + ); - let new_list = store_list(env, ptr_type.const_zero(), zero); - - builder.build_store(result, new_list); - builder.build_unconditional_branch(cont_block); - } - - // build block for length > 0 - { - builder.position_at_end(len_n_block); - - let new_list_ptr = allocate_list(env, output_inplace, &element_layout, length); - - let new_length = list_keep_if_help( - env, - InPlace::Clone, - parent, - length, - list_ptr, - new_list_ptr, - func, - func_layout, - ); - - // store new list pointer there - let new_list = store_list(env, new_list_ptr, new_length); - - builder.build_store(result, new_list); - builder.build_unconditional_branch(cont_block); - } - - builder.position_at_end(cont_block); - - // consume the input list - decrement_refcount_layout(env, parent, layout_ids, list, list_layout); - - builder.build_load(result, "load_result") - } - } -} - -#[allow(clippy::too_many_arguments)] -pub fn list_keep_if_help<'a, 'ctx, 'env>( - env: &Env<'a, 'ctx, 'env>, - _inplace: InPlace, - parent: FunctionValue<'ctx>, - length: IntValue<'ctx>, - source_ptr: PointerValue<'ctx>, - dest_ptr: PointerValue<'ctx>, - func: BasicValueEnum<'ctx>, - func_layout: &Layout<'a>, -) -> IntValue<'ctx> { - match (func, func_layout) { - ( - BasicValueEnum::PointerValue(func_ptr), - Layout::FunctionPointer(_, Layout::Builtin(Builtin::Int1)), - ) => { - let builder = env.builder; - let ctx = env.context; - - let index_alloca = builder.build_alloca(ctx.i64_type(), "index_alloca"); - let next_free_index_alloca = - builder.build_alloca(ctx.i64_type(), "next_free_index_alloca"); - - builder.build_store(index_alloca, ctx.i64_type().const_zero()); - builder.build_store(next_free_index_alloca, ctx.i64_type().const_zero()); - - // while (length > next_index) - let condition_bb = ctx.append_basic_block(parent, "condition"); - builder.build_unconditional_branch(condition_bb); - builder.position_at_end(condition_bb); - - let index = builder.build_load(index_alloca, "index").into_int_value(); - - let condition = builder.build_int_compare(IntPredicate::SGT, length, index, "loopcond"); - - let body_bb = ctx.append_basic_block(parent, "body"); - let cont_bb = ctx.append_basic_block(parent, "cont"); - builder.build_conditional_branch(condition, body_bb, cont_bb); - - // loop body - builder.position_at_end(body_bb); - - let elem_ptr = unsafe { builder.build_in_bounds_gep(source_ptr, &[index], "elem_ptr") }; - - let elem = builder.build_load(elem_ptr, "load_elem"); - - let call_site_value = - builder.build_call(func_ptr, env.arena.alloc([elem]), "#keep_if_insert_func"); - - // set the calling convention explicitly for this call - call_site_value.set_call_convention(crate::llvm::build::FAST_CALL_CONV); - - let should_keep = call_site_value - .try_as_basic_value() - .left() - .unwrap_or_else(|| panic!("LLVM error: Invalid call by pointer.")) - .into_int_value(); - - let filter_pass_bb = ctx.append_basic_block(parent, "loop"); - let after_filter_pass_bb = ctx.append_basic_block(parent, "after_loop"); - - let one = ctx.i64_type().const_int(1, false); - - builder.build_conditional_branch(should_keep, filter_pass_bb, after_filter_pass_bb); - builder.position_at_end(filter_pass_bb); - - let next_free_index = builder - .build_load(next_free_index_alloca, "load_next_free") - .into_int_value(); - - // TODO if next_free_index equals index, and we are mutating in place, - // then maybe we should not write this value back into memory - let dest_elem_ptr = unsafe { - builder.build_in_bounds_gep(dest_ptr, &[next_free_index], "dest_elem_ptr") - }; - - builder.build_store(dest_elem_ptr, elem); - - builder.build_store( - next_free_index_alloca, - builder.build_int_add(next_free_index, one, "incremented_next_free_index"), - ); - - builder.build_unconditional_branch(after_filter_pass_bb); - builder.position_at_end(after_filter_pass_bb); - - builder.build_store( - index_alloca, - builder.build_int_add(index, one, "incremented_index"), - ); - - builder.build_unconditional_branch(condition_bb); - - // continuation - builder.position_at_end(cont_bb); - - builder - .build_load(next_free_index_alloca, "new_length") - .into_int_value() - } - _ => unreachable!( - "Invalid function basic value enum or layout for List.keepIf : {:?}", - (func, func_layout) - ), - } + complex_bitcast( + env.builder, + output, + collection(env.context, env.ptr_bytes).into(), + "from_i128", + ) } /// List.map : List before, (before -> after) -> List after -macro_rules! list_map_help { - ($env:expr, $layout_ids:expr, $inplace:expr, $parent:expr, $func:expr, $func_layout:expr, $list:expr, $list_layout:expr, $function_ptr:expr, $function_return_layout: expr, $closure_info:expr) => {{ - let layout_ids = $layout_ids; - let inplace = $inplace; - let parent = $parent; - let func = $func; - let func_layout = $func_layout; - let list = $list; - let list_layout = $list_layout; - let function_ptr = $function_ptr; - let function_return_layout = $function_return_layout; - let closure_info : Option<(&Layout, BasicValueEnum)> = $closure_info; - - - let non_empty_fn = |elem_layout: &Layout<'a>, - len: IntValue<'ctx>, - list_wrapper: StructValue<'ctx>| { - let ctx = $env.context; - let builder = $env.builder; - - let ret_list_ptr = allocate_list($env, inplace, function_return_layout, len); - - let elem_type = basic_type_from_layout($env.arena, ctx, elem_layout, $env.ptr_bytes); - let ptr_type = get_ptr_type(&elem_type, AddressSpace::Generic); - - let list_ptr = load_list_ptr(builder, list_wrapper, ptr_type); - - let list_loop = |index, before_elem| { - increment_refcount_layout($env, parent, layout_ids, 1, before_elem, elem_layout); - - let arguments = match closure_info { - Some((closure_data_layout, closure_data)) => { - increment_refcount_layout( $env, parent, layout_ids, 1, closure_data, closure_data_layout); - - bumpalo::vec![in $env.arena; before_elem, closure_data] - } - None => bumpalo::vec![in $env.arena; before_elem], - }; - - - let call_site_value = builder.build_call(function_ptr, &arguments, "map_func"); - - // set the calling convention explicitly for this call - call_site_value.set_call_convention(crate::llvm::build::FAST_CALL_CONV); - - let after_elem = call_site_value - .try_as_basic_value() - .left() - .unwrap_or_else(|| panic!("LLVM error: Invalid call by pointer.")); - - // The pointer to the element in the mapped-over list - let after_elem_ptr = unsafe { - builder.build_in_bounds_gep(ret_list_ptr, &[index], "load_index_after_list") - }; - - // Mutate the new array in-place to change the element. - builder.build_store(after_elem_ptr, after_elem); - }; - - incrementing_elem_loop(builder, ctx, parent, list_ptr, len, "#index", list_loop); - - let result = store_list($env, ret_list_ptr, len); - - // decrement the input list and function (if it's a closure) - decrement_refcount_layout($env, parent, layout_ids, list, list_layout); - decrement_refcount_layout($env, parent, layout_ids, func, func_layout); - - if let Some((closure_data_layout, closure_data)) = closure_info { - decrement_refcount_layout( $env, parent, layout_ids, closure_data, closure_data_layout); - } - - result - }; - - if_list_is_not_empty($env, parent, non_empty_fn, list, list_layout, "List.map") - }}; -} - -/// List.map : List before, (before -> after) -> List after -#[allow(clippy::too_many_arguments)] pub fn list_map<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, - inplace: InPlace, - parent: FunctionValue<'ctx>, - func: BasicValueEnum<'ctx>, - func_layout: &Layout<'a>, + transform: BasicValueEnum<'ctx>, + transform_layout: &Layout<'a>, list: BasicValueEnum<'ctx>, - list_layout: &Layout<'a>, + element_layout: &Layout<'a>, ) -> BasicValueEnum<'ctx> { - match (func, func_layout) { - (BasicValueEnum::PointerValue(func_ptr), Layout::FunctionPointer(_, ret_elem_layout)) => { - list_map_help!( - env, - layout_ids, - inplace, - parent, - func, - func_layout, - list, - list_layout, - func_ptr, - ret_elem_layout, - None - ) - } - ( - BasicValueEnum::StructValue(ptr_and_data), - Layout::Closure(_, closure_layout, ret_elem_layout), - ) => { - let builder = env.builder; + list_map_generic( + env, + layout_ids, + transform, + transform_layout, + list, + element_layout, + bitcode::LIST_MAP, + &[element_layout.clone()], + ) +} - let func_ptr = builder - .build_extract_value(ptr_and_data, 0, "function_ptr") - .unwrap() - .into_pointer_value(); +/// List.mapWithIndex : List before, (Nat, before -> after) -> List after +pub fn list_map_with_index<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + transform: BasicValueEnum<'ctx>, + transform_layout: &Layout<'a>, + list: BasicValueEnum<'ctx>, + element_layout: &Layout<'a>, +) -> BasicValueEnum<'ctx> { + list_map_generic( + env, + layout_ids, + transform, + transform_layout, + list, + element_layout, + bitcode::LIST_MAP_WITH_INDEX, + &[Layout::Builtin(Builtin::Usize), element_layout.clone()], + ) +} - let closure_data = builder - .build_extract_value(ptr_and_data, 1, "closure_data") - .unwrap(); +fn list_map_generic<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + transform: BasicValueEnum<'ctx>, + transform_layout: &Layout<'a>, + list: BasicValueEnum<'ctx>, + element_layout: &Layout<'a>, + op: &str, + argument_layouts: &[Layout<'a>], +) -> BasicValueEnum<'ctx> { + let builder = env.builder; - let closure_data_layout = closure_layout.as_block_of_memory_layout(); + let return_layout = match transform_layout { + Layout::FunctionPointer(_, ret) => ret, + Layout::Closure(_, _, ret) => ret, + _ => unreachable!("not a callable layout"), + }; - list_map_help!( - env, - layout_ids, - inplace, - parent, - func, - func_layout, - list, - list_layout, - func_ptr, - ret_elem_layout, - Some((&closure_data_layout, closure_data)) - ) - } - _ => { - unreachable!( - "Invalid function basic value enum or layout for List.map : {:?}", - (func, func_layout) - ); - } - } + let u8_ptr = env.context.i8_type().ptr_type(AddressSpace::Generic); + + let list_i128 = complex_bitcast(env.builder, list, env.context.i128_type().into(), "to_i128"); + + let transform_ptr = builder.build_alloca(transform.get_type(), "transform_ptr"); + env.builder.build_store(transform_ptr, transform); + + let stepper_caller = + build_transform_caller(env, layout_ids, transform_layout, argument_layouts) + .as_global_value() + .as_pointer_value(); + + let old_element_width = env + .ptr_int() + .const_int(element_layout.stack_size(env.ptr_bytes) as u64, false); + + let new_element_width = env + .ptr_int() + .const_int(return_layout.stack_size(env.ptr_bytes) as u64, false); + + let alignment = element_layout.alignment_bytes(env.ptr_bytes); + let alignment_iv = env.ptr_int().const_int(alignment as u64, false); + + let output = call_bitcode_fn( + env, + &[ + list_i128, + env.builder + .build_bitcast(transform_ptr, u8_ptr, "to_opaque"), + stepper_caller.into(), + alignment_iv.into(), + old_element_width.into(), + new_element_width.into(), + ], + op, + ); + + complex_bitcast( + env.builder, + output, + collection(env.context, env.ptr_bytes).into(), + "from_i128", + ) } /// List.concat : List elem, List elem -> List elem @@ -1856,59 +1607,6 @@ where index_alloca } -// This function checks if the list is empty, and -// if it is, it returns an empty list, and if not -// it runs whatever code is passed in under `build_non_empty` -// This is the avoid allocating memory if the list is empty. -fn if_list_is_not_empty<'a, 'ctx, 'env, 'b, NonEmptyFn>( - env: &Env<'a, 'ctx, 'env>, - parent: FunctionValue<'ctx>, - mut build_non_empty: NonEmptyFn, - list: BasicValueEnum<'ctx>, - list_layout: &Layout<'a>, - list_fn_name: &str, -) -> BasicValueEnum<'ctx> -where - NonEmptyFn: FnMut(&Layout<'a>, IntValue<'ctx>, StructValue<'ctx>) -> BasicValueEnum<'ctx>, -{ - match list_layout { - Layout::Builtin(Builtin::EmptyList) => empty_list(env), - - Layout::Builtin(Builtin::List(_, elem_layout)) => { - let builder = env.builder; - let ctx = env.context; - - let wrapper_struct = list.into_struct_value(); - - let len = list_len(builder, wrapper_struct); - - // list_len > 0 - let comparison = builder.build_int_compare( - IntPredicate::UGT, - len, - ctx.i64_type().const_int(0, false), - "greaterthanzero", - ); - - let build_empty = || empty_list(env); - - let struct_type = collection(ctx, env.ptr_bytes); - - build_basic_phi2( - env, - parent, - comparison, - || build_non_empty(elem_layout, len, wrapper_struct), - build_empty, - BasicTypeEnum::StructType(struct_type), - ) - } - _ => { - unreachable!("Invalid List layout for {} {:?}", list_fn_name, list_layout); - } - } -} - pub fn build_basic_phi2<'a, 'ctx, 'env, PassFn, FailFn>( env: &Env<'a, 'ctx, 'env>, parent: FunctionValue<'ctx>, diff --git a/compiler/gen/src/llvm/build_str.rs b/compiler/gen/src/llvm/build_str.rs index 3b1fd0dbac..2ab140db34 100644 --- a/compiler/gen/src/llvm/build_str.rs +++ b/compiler/gen/src/llvm/build_str.rs @@ -1,6 +1,5 @@ -use crate::llvm::build::{ - call_bitcode_fn, call_void_bitcode_fn, complex_bitcast, Env, InPlace, Scope, -}; +use crate::llvm::bitcode::{call_bitcode_fn, call_void_bitcode_fn}; +use crate::llvm::build::{complex_bitcast, Env, InPlace, Scope}; use crate::llvm::build_list::{allocate_list, store_list}; use crate::llvm::convert::collection; use inkwell::builder::Builder; @@ -274,6 +273,19 @@ pub fn str_from_int<'a, 'ctx, 'env>( zig_str_to_struct(env, zig_result).into() } +/// Str.fromInt : Int -> Str +pub fn str_from_float<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + scope: &Scope<'a, 'ctx>, + int_symbol: Symbol, +) -> BasicValueEnum<'ctx> { + let float = load_symbol(scope, &int_symbol); + + let zig_result = call_bitcode_fn(env, &[float], &bitcode::STR_FROM_FLOAT).into_struct_value(); + + zig_str_to_struct(env, zig_result).into() +} + /// Str.equal : Str, Str -> Bool pub fn str_equal<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, diff --git a/compiler/gen/src/llvm/compare.rs b/compiler/gen/src/llvm/compare.rs index ec1f0ebd05..cd733b3bd1 100644 --- a/compiler/gen/src/llvm/compare.rs +++ b/compiler/gen/src/llvm/compare.rs @@ -366,7 +366,6 @@ fn build_list_eq<'a, 'ctx, 'env>( list2: StructValue<'ctx>, when_recursive: WhenRecursive<'a>, ) -> BasicValueEnum<'ctx> { - dbg!("list", &when_recursive); let block = env.builder.get_insert_block().expect("to be in a function"); let di_location = env.builder.get_current_debug_location().unwrap(); diff --git a/compiler/gen/src/llvm/convert.rs b/compiler/gen/src/llvm/convert.rs index f5a319aaf8..9cefbc3fbc 100644 --- a/compiler/gen/src/llvm/convert.rs +++ b/compiler/gen/src/llvm/convert.rs @@ -283,6 +283,10 @@ pub fn dict(ctx: &Context, ptr_bytes: u32) -> StructType<'_> { ) } +pub fn dict_ptr(ctx: &Context, ptr_bytes: u32) -> PointerType<'_> { + dict(ctx, ptr_bytes).ptr_type(AddressSpace::Generic) +} + pub fn ptr_int(ctx: &Context, ptr_bytes: u32) -> IntType<'_> { match ptr_bytes { 1 => ctx.i8_type(), diff --git a/compiler/gen/src/llvm/mod.rs b/compiler/gen/src/llvm/mod.rs index 1d4e0bb4fd..195cbc9c98 100644 --- a/compiler/gen/src/llvm/mod.rs +++ b/compiler/gen/src/llvm/mod.rs @@ -1,3 +1,4 @@ +pub mod bitcode; pub mod build; pub mod build_dict; pub mod build_hash; diff --git a/compiler/gen/tests/gen_compare.rs b/compiler/gen/tests/gen_compare.rs index 2a119d4869..10bca099ee 100644 --- a/compiler/gen/tests/gen_compare.rs +++ b/compiler/gen/tests/gen_compare.rs @@ -416,4 +416,40 @@ mod gen_num { bool ); } + + #[test] + fn list_eq_empty() { + assert_evals_to!("[] == []", true, bool); + assert_evals_to!("[] != []", false, bool); + } + + #[test] + fn list_eq_by_length() { + assert_evals_to!("[1] == []", false, bool); + assert_evals_to!("[] == [1]", false, bool); + } + + #[test] + fn list_eq_compare_pointwise() { + assert_evals_to!("[1] == [1]", true, bool); + assert_evals_to!("[2] == [1]", false, bool); + } + + #[test] + fn list_eq_nested() { + assert_evals_to!("[[1]] == [[1]]", true, bool); + assert_evals_to!("[[2]] == [[1]]", false, bool); + } + + #[test] + fn list_neq_compare_pointwise() { + assert_evals_to!("[1] != [1]", false, bool); + assert_evals_to!("[2] != [1]", true, bool); + } + + #[test] + fn list_neq_nested() { + assert_evals_to!("[[1]] != [[1]]", false, bool); + assert_evals_to!("[[2]] != [[1]]", true, bool); + } } diff --git a/compiler/gen/tests/gen_dict.rs b/compiler/gen/tests/gen_dict.rs index b837a886c0..2e2df034ca 100644 --- a/compiler/gen/tests/gen_dict.rs +++ b/compiler/gen/tests/gen_dict.rs @@ -305,4 +305,218 @@ mod gen_dict { &[RocStr] ); } + + #[test] + fn unit_values() { + assert_evals_to!( + indoc!( + r#" + myDict : Dict I64 {} + myDict = + Dict.empty + |> Dict.insert 0 {} + |> Dict.insert 1 {} + |> Dict.insert 2 {} + |> Dict.insert 3 {} + + Dict.len myDict + "# + ), + 4, + i64 + ); + } + + #[test] + fn singleton() { + assert_evals_to!( + indoc!( + r#" + myDict : Dict I64 {} + myDict = + Dict.singleton 0 {} + + Dict.len myDict + "# + ), + 1, + i64 + ); + } + + #[test] + fn union() { + assert_evals_to!( + indoc!( + r#" + myDict : Dict I64 {} + myDict = + Dict.union (Dict.singleton 0 {}) (Dict.singleton 1 {}) + + Dict.len myDict + "# + ), + 2, + i64 + ); + } + + #[test] + fn union_prefer_first() { + assert_evals_to!( + indoc!( + r#" + myDict : Dict I64 I64 + myDict = + Dict.union (Dict.singleton 0 100) (Dict.singleton 0 200) + + Dict.values myDict + "# + ), + &[100], + &[i64] + ); + } + + #[test] + fn intersection() { + assert_evals_to!( + indoc!( + r#" + dict1 : Dict I64 {} + dict1 = + Dict.empty + |> Dict.insert 1 {} + |> Dict.insert 2 {} + |> Dict.insert 3 {} + |> Dict.insert 4 {} + |> Dict.insert 5 {} + + dict2 : Dict I64 {} + dict2 = + Dict.empty + |> Dict.insert 0 {} + |> Dict.insert 2 {} + |> Dict.insert 4 {} + + Dict.intersection dict1 dict2 + |> Dict.len + "# + ), + 2, + i64 + ); + } + + #[test] + fn intersection_prefer_first() { + assert_evals_to!( + indoc!( + r#" + dict1 : Dict I64 I64 + dict1 = + Dict.empty + |> Dict.insert 1 1 + |> Dict.insert 2 2 + |> Dict.insert 3 3 + |> Dict.insert 4 4 + |> Dict.insert 5 5 + + dict2 : Dict I64 I64 + dict2 = + Dict.empty + |> Dict.insert 0 100 + |> Dict.insert 2 200 + |> Dict.insert 4 300 + + Dict.intersection dict1 dict2 + |> Dict.values + "# + ), + &[4, 2], + &[i64] + ); + } + + #[test] + fn difference() { + assert_evals_to!( + indoc!( + r#" + dict1 : Dict I64 {} + dict1 = + Dict.empty + |> Dict.insert 1 {} + |> Dict.insert 2 {} + |> Dict.insert 3 {} + |> Dict.insert 4 {} + |> Dict.insert 5 {} + + dict2 : Dict I64 {} + dict2 = + Dict.empty + |> Dict.insert 0 {} + |> Dict.insert 2 {} + |> Dict.insert 4 {} + + Dict.difference dict1 dict2 + |> Dict.len + "# + ), + 3, + i64 + ); + } + + #[test] + fn difference_prefer_first() { + assert_evals_to!( + indoc!( + r#" + dict1 : Dict I64 I64 + dict1 = + Dict.empty + |> Dict.insert 1 1 + |> Dict.insert 2 2 + |> Dict.insert 3 3 + |> Dict.insert 4 4 + |> Dict.insert 5 5 + + dict2 : Dict I64 I64 + dict2 = + Dict.empty + |> Dict.insert 0 100 + |> Dict.insert 2 200 + |> Dict.insert 4 300 + + Dict.difference dict1 dict2 + |> Dict.values + "# + ), + &[5, 3, 1], + &[i64] + ); + } + + #[test] + fn walk_sum_keys() { + assert_evals_to!( + indoc!( + r#" + dict1 : Dict I64 I64 + dict1 = + Dict.empty + |> Dict.insert 1 1 + |> Dict.insert 2 2 + |> Dict.insert 3 3 + |> Dict.insert 4 4 + |> Dict.insert 5 5 + + Dict.walk dict1 (\k, _, a -> k + a) 0 + "# + ), + 15, + i64 + ); + } } diff --git a/compiler/gen/tests/gen_list.rs b/compiler/gen/tests/gen_list.rs index 403cdbee33..60adc5cf02 100644 --- a/compiler/gen/tests/gen_list.rs +++ b/compiler/gen/tests/gen_list.rs @@ -1710,39 +1710,32 @@ mod gen_list { } #[test] - fn list_eq_empty() { - assert_evals_to!("[] == []", true, bool); - assert_evals_to!("[] != []", false, bool); + fn list_keep_oks() { + assert_evals_to!("List.keepOks [] (\\x -> x)", 0, i64); + assert_evals_to!("List.keepOks [1,2] (\\x -> Ok x)", &[1, 2], &[i64]); + assert_evals_to!("List.keepOks [1,2] (\\x -> x % 2)", &[1, 0], &[i64]); + assert_evals_to!("List.keepOks [Ok 1, Err 2] (\\x -> x)", &[1], &[i64]); } #[test] - fn list_eq_by_length() { - assert_evals_to!("[1] == []", false, bool); - assert_evals_to!("[] == [1]", false, bool); + fn list_keep_errs() { + assert_evals_to!("List.keepErrs [] (\\x -> x)", 0, i64); + assert_evals_to!("List.keepErrs [1,2] (\\x -> Err x)", &[1, 2], &[i64]); + assert_evals_to!( + "List.keepErrs [0,1,2] (\\x -> x % 0 |> Result.mapErr (\\_ -> 32))", + &[32, 32, 32], + &[i64] + ); + assert_evals_to!("List.keepErrs [Ok 1, Err 2] (\\x -> x)", &[2], &[i64]); } #[test] - fn list_eq_compare_pointwise() { - assert_evals_to!("[1] == [1]", true, bool); - assert_evals_to!("[2] == [1]", false, bool); - } - - #[test] - fn list_eq_nested() { - assert_evals_to!("[[1]] == [[1]]", true, bool); - assert_evals_to!("[[2]] == [[1]]", false, bool); - } - - #[test] - fn list_neq_compare_pointwise() { - assert_evals_to!("[1] != [1]", false, bool); - assert_evals_to!("[2] != [1]", true, bool); - } - - #[test] - fn list_neq_nested() { - assert_evals_to!("[[1]] != [[1]]", false, bool); - assert_evals_to!("[[2]] != [[1]]", true, bool); + fn list_map_with_index() { + assert_evals_to!( + "List.mapWithIndex [0,0,0] (\\index, x -> index + x)", + &[0, 1, 2], + &[i64] + ); } #[test] diff --git a/compiler/gen/tests/gen_result.rs b/compiler/gen/tests/gen_result.rs new file mode 100644 index 0000000000..be63bf39ca --- /dev/null +++ b/compiler/gen/tests/gen_result.rs @@ -0,0 +1,111 @@ +#[macro_use] +extern crate pretty_assertions; +#[macro_use] +extern crate indoc; + +extern crate bumpalo; +extern crate inkwell; +extern crate libc; +extern crate roc_gen; + +#[macro_use] +mod helpers; + +#[cfg(test)] +mod gen_result { + + #[test] + fn with_default() { + assert_evals_to!( + indoc!( + r#" + result : Result I64 {} + result = Ok 2 + + Result.withDefault result 0 + "# + ), + 2, + i64 + ); + + assert_evals_to!( + indoc!( + r#" + result : Result I64 {} + result = Err {} + + Result.withDefault result 0 + "# + ), + 0, + i64 + ); + } + + #[test] + fn result_map() { + assert_evals_to!( + indoc!( + r#" + result : Result I64 {} + result = Ok 2 + + result + |> Result.map (\x -> x + 1) + |> Result.withDefault 0 + "# + ), + 3, + i64 + ); + + assert_evals_to!( + indoc!( + r#" + result : Result I64 {} + result = Err {} + + result + |> Result.map (\x -> x + 1) + |> Result.withDefault 0 + "# + ), + 0, + i64 + ); + } + + #[test] + fn result_map_err() { + assert_evals_to!( + indoc!( + r#" + result : Result {} I64 + result = Err 2 + + when Result.mapErr result (\x -> x + 1) is + Err n -> n + Ok _ -> 0 + "# + ), + 3, + i64 + ); + + assert_evals_to!( + indoc!( + r#" + result : Result {} I64 + result = Ok {} + + when Result.mapErr result (\x -> x + 1) is + Err n -> n + Ok _ -> 0 + "# + ), + 0, + i64 + ); + } +} diff --git a/compiler/gen/tests/gen_set.rs b/compiler/gen/tests/gen_set.rs new file mode 100644 index 0000000000..012eaf64ab --- /dev/null +++ b/compiler/gen/tests/gen_set.rs @@ -0,0 +1,248 @@ +#[macro_use] +extern crate pretty_assertions; +#[macro_use] +extern crate indoc; + +extern crate bumpalo; +extern crate inkwell; +extern crate libc; +extern crate roc_gen; + +#[macro_use] +mod helpers; + +#[cfg(test)] +mod gen_set { + + #[test] + fn empty_len() { + assert_evals_to!( + indoc!( + r#" + Set.len Set.empty + "# + ), + 0, + usize + ); + } + + #[test] + fn singleton_len() { + assert_evals_to!( + indoc!( + r#" + Set.len (Set.singleton 42) + "# + ), + 1, + usize + ); + } + + #[test] + fn singleton_to_list() { + assert_evals_to!( + indoc!( + r#" + Set.toList (Set.singleton 42) + "# + ), + &[42], + &[i64] + ); + + assert_evals_to!( + indoc!( + r#" + Set.toList (Set.singleton 1) + "# + ), + &[1], + &[i64] + ); + + assert_evals_to!( + indoc!( + r#" + Set.toList (Set.singleton 1.0) + "# + ), + &[1.0], + &[f64] + ); + } + + #[test] + fn insert() { + assert_evals_to!( + indoc!( + r#" + Set.empty + |> Set.insert 0 + |> Set.insert 1 + |> Set.insert 2 + |> Set.toList + "# + ), + &[0, 1, 2], + &[i64] + ); + } + + #[test] + fn remove() { + assert_evals_to!( + indoc!( + r#" + Set.empty + |> Set.insert 0 + |> Set.insert 1 + |> Set.remove 1 + |> Set.remove 2 + |> Set.toList + "# + ), + &[0], + &[i64] + ); + } + + #[test] + fn union() { + assert_evals_to!( + indoc!( + r#" + set1 : Set I64 + set1 = Set.fromList [1,2] + + set2 : Set I64 + set2 = Set.fromList [1,3,4] + + Set.union set1 set2 + |> Set.toList + "# + ), + &[4, 2, 3, 1], + &[i64] + ); + } + + #[test] + fn difference() { + assert_evals_to!( + indoc!( + r#" + set1 : Set I64 + set1 = Set.fromList [1,2] + + set2 : Set I64 + set2 = Set.fromList [1,3,4] + + Set.difference set1 set2 + |> Set.toList + "# + ), + &[2], + &[i64] + ); + } + + #[test] + fn intersection() { + assert_evals_to!( + indoc!( + r#" + set1 : Set I64 + set1 = Set.fromList [1,2] + + set2 : Set I64 + set2 = Set.fromList [1,3,4] + + Set.intersection set1 set2 + |> Set.toList + "# + ), + &[1], + &[i64] + ); + } + + #[test] + fn walk_sum() { + assert_evals_to!( + indoc!( + r#" + Set.walk (Set.fromList [1,2,3]) (\x, y -> x + y) 0 + "# + ), + 6, + i64 + ); + } + + #[test] + fn contains() { + assert_evals_to!( + indoc!( + r#" + Set.contains (Set.fromList [1,3,4]) 4 + "# + ), + true, + bool + ); + + assert_evals_to!( + indoc!( + r#" + Set.contains (Set.fromList [1,3,4]) 2 + "# + ), + false, + bool + ); + } + + #[test] + fn from_list() { + assert_evals_to!( + indoc!( + r#" + [1,2,2,3,1,4] + |> Set.fromList + |> Set.toList + "# + ), + &[4, 2, 3, 1], + &[i64] + ); + + assert_evals_to!( + indoc!( + r#" + [] + |> Set.fromList + |> Set.toList + "# + ), + &[], + &[i64] + ); + + assert_evals_to!( + indoc!( + r#" + empty : List I64 + empty = [] + + empty + |> Set.fromList + |> Set.toList + "# + ), + &[], + &[i64] + ); + } +} diff --git a/compiler/gen/tests/gen_str.rs b/compiler/gen/tests/gen_str.rs index c405fbd6bd..62230dc73a 100644 --- a/compiler/gen/tests/gen_str.rs +++ b/compiler/gen/tests/gen_str.rs @@ -592,4 +592,9 @@ mod gen_str { fn str_join_comma_single() { assert_evals_to!(r#"Str.joinWith ["1"] ", " "#, RocStr::from("1"), RocStr); } + + #[test] + fn str_from_float() { + assert_evals_to!(r#"Str.fromFloat 3.14"#, RocStr::from("3.140000"), RocStr); + } } diff --git a/compiler/gen/tests/helpers/eval.rs b/compiler/gen/tests/helpers/eval.rs index e6f3f962df..61c48332a4 100644 --- a/compiler/gen/tests/helpers/eval.rs +++ b/compiler/gen/tests/helpers/eval.rs @@ -275,7 +275,7 @@ pub fn helper<'a>( mode, ); - fn_val.print_to_stderr(); + // fn_val.print_to_stderr(); // module.print_to_stderr(); panic!( diff --git a/compiler/load/tests/fixtures/build/app_with_deps/AStar.roc b/compiler/load/tests/fixtures/build/app_with_deps/AStar.roc index b93c6937f3..64d66c1e78 100644 --- a/compiler/load/tests/fixtures/build/app_with_deps/AStar.roc +++ b/compiler/load/tests/fixtures/build/app_with_deps/AStar.roc @@ -42,7 +42,7 @@ cheapestOpen = \costFunction, model -> else Ok smallestSoFar - Set.foldl model.openSet folder (Err KeyNotFound) + Set.walk model.openSet folder (Err KeyNotFound) |> Result.map (\x -> x.position) @@ -101,11 +101,11 @@ astar = \costFn, moveFn, goal, model -> neighbours = moveFn current - newNeighbours = Set.diff neighbours modelPopped.evaluated + newNeighbours = Set.difference neighbours modelPopped.evaluated modelWithNeighbours = { modelPopped & openSet : Set.union modelPopped.openSet newNeighbours } - modelWithCosts = Set.foldl newNeighbours (\nb, md -> updateCost current nb md) modelWithNeighbours + modelWithCosts = Set.walk newNeighbours (\nb, md -> updateCost current nb md) modelWithNeighbours astar costFn moveFn goal modelWithCosts diff --git a/compiler/load/tests/fixtures/build/interface_with_deps/AStar.roc b/compiler/load/tests/fixtures/build/interface_with_deps/AStar.roc index 4ddbb5946b..40932f68dc 100644 --- a/compiler/load/tests/fixtures/build/interface_with_deps/AStar.roc +++ b/compiler/load/tests/fixtures/build/interface_with_deps/AStar.roc @@ -42,7 +42,7 @@ cheapestOpen = \costFunction, model -> else Ok smallestSoFar - Set.foldl model.openSet folder (Err KeyNotFound) + Set.walk model.openSet folder (Err KeyNotFound) |> Result.map (\x -> x.position) @@ -101,11 +101,11 @@ astar = \costFn, moveFn, goal, model -> neighbours = moveFn current - newNeighbours = Set.diff neighbours modelPopped.evaluated + newNeighbours = Set.difference neighbours modelPopped.evaluated modelWithNeighbours = { modelPopped & openSet : Set.union modelPopped.openSet newNeighbours } - modelWithCosts = Set.foldl newNeighbours (\nb, md -> updateCost current nb md) modelWithNeighbours + modelWithCosts = Set.walk newNeighbours (\nb, md -> updateCost current nb md) modelWithNeighbours astar costFn moveFn goal modelWithCosts diff --git a/compiler/module/src/low_level.rs b/compiler/module/src/low_level.rs index 3416f0a073..e69fa0dd02 100644 --- a/compiler/module/src/low_level.rs +++ b/compiler/module/src/low_level.rs @@ -11,6 +11,7 @@ pub enum LowLevel { StrSplit, StrCountGraphemes, StrFromInt, + StrFromFloat, ListLen, ListGetUnsafe, ListSet, @@ -24,10 +25,13 @@ pub enum LowLevel { ListPrepend, ListJoin, ListMap, + ListMapWithIndex, ListKeepIf, ListWalk, ListWalkBackwards, ListSum, + ListKeepOks, + ListKeepErrs, DictSize, DictEmpty, DictInsert, @@ -36,6 +40,11 @@ pub enum LowLevel { DictGetUnsafe, DictKeys, DictValues, + DictUnion, + DictIntersection, + DictDifference, + DictWalk, + SetFromList, NumAdd, NumAddWrap, NumAddChecked, diff --git a/compiler/module/src/symbol.rs b/compiler/module/src/symbol.rs index 0a745dd6e3..25fb2cb0fb 100644 --- a/compiler/module/src/symbol.rs +++ b/compiler/module/src/symbol.rs @@ -749,6 +749,13 @@ define_builtins! { 17 GENERIC_RC_REF: "#generic_rc_by_ref" // refcount of arbitrary layouts, passed as an opaque pointer 18 GENERIC_EQ: "#generic_eq" // internal function that checks generic equality + + // a user-defined function that we need to capture in a closure + // see e.g. Set.walk + 19 USER_FUNCTION: "#user_function" + + // A caller (wrapper) that we pass to zig for it to be able to call Roc functions + 20 ZIG_FUNCTION_CALLER: "#zig_function_caller" } 1 NUM: "Num" => { 0 NUM_NUM: "Num" imported // the Num.Num type alias @@ -865,6 +872,7 @@ define_builtins! { 8 STR_STARTS_WITH: "startsWith" 9 STR_ENDS_WITH: "endsWith" 10 STR_FROM_INT: "fromInt" + 11 STR_FROM_FLOAT: "fromFloat" } 4 LIST: "List" => { 0 LIST_LIST: "List" imported // the List.List type alias @@ -888,10 +896,15 @@ define_builtins! { 18 LIST_SUM: "sum" 19 LIST_WALK: "walk" 20 LIST_LAST: "last" + 21 LIST_KEEP_OKS: "keepOks" + 22 LIST_KEEP_ERRS: "keepErrs" + 23 LIST_MAP_WITH_INDEX: "mapWithIndex" } 5 RESULT: "Result" => { 0 RESULT_RESULT: "Result" imported // the Result.Result type alias 1 RESULT_MAP: "map" + 2 RESULT_MAP_ERR: "mapErr" + 3 RESULT_WITH_DEFAULT: "withDefault" } 6 DICT: "Dict" => { 0 DICT_DICT: "Dict" imported // the Dict.Dict type alias @@ -910,17 +923,30 @@ define_builtins! { 9 DICT_CONTAINS: "contains" 10 DICT_KEYS: "keys" 11 DICT_VALUES: "values" + + 12 DICT_UNION: "union" + 13 DICT_INTERSECTION: "intersection" + 14 DICT_DIFFERENCE: "difference" + + 15 DICT_WALK: "walk" + } 7 SET: "Set" => { 0 SET_SET: "Set" imported // the Set.Set type alias 1 SET_AT_SET: "@Set" // the Set.@Set private tag 2 SET_EMPTY: "empty" 3 SET_SINGLETON: "singleton" - 4 SET_UNION: "union" - 5 SET_FOLDL: "foldl" - 6 SET_INSERT: "insert" - 7 SET_REMOVE: "remove" - 8 SET_DIFF: "diff" + 4 SET_LEN: "len" + 5 SET_INSERT: "insert" + 6 SET_REMOVE: "remove" + 7 SET_UNION: "union" + 8 SET_DIFFERENCE: "difference" + 9 SET_INTERSECTION: "intersection" + 10 SET_TO_LIST: "toList" + 11 SET_FROM_LIST: "fromList" + 12 SET_WALK: "walk" + 13 SET_WALK_USER_FUNCTION: "#walk_user_function" + 14 SET_CONTAINS: "contains" } num_modules: 8 // Keep this count up to date by hand! (TODO: see the mut_map! macro for how we could determine this count correctly in the macro) diff --git a/compiler/mono/src/borrow.rs b/compiler/mono/src/borrow.rs index 6304f1cf64..f142d01368 100644 --- a/compiler/mono/src/borrow.rs +++ b/compiler/mono/src/borrow.rs @@ -640,11 +640,11 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] { ListPrepend => arena.alloc_slice_copy(&[owned, owned]), StrJoinWith => arena.alloc_slice_copy(&[irrelevant, irrelevant]), ListJoin => arena.alloc_slice_copy(&[irrelevant]), - ListMap => arena.alloc_slice_copy(&[owned, irrelevant]), - ListKeepIf => arena.alloc_slice_copy(&[owned, irrelevant]), + ListMap | ListMapWithIndex => arena.alloc_slice_copy(&[owned, irrelevant]), + ListKeepIf | ListKeepOks | ListKeepErrs => arena.alloc_slice_copy(&[owned, irrelevant]), ListContains => arena.alloc_slice_copy(&[borrowed, irrelevant]), - ListWalk => arena.alloc_slice_copy(&[borrowed, irrelevant, owned]), - ListWalkBackwards => arena.alloc_slice_copy(&[borrowed, irrelevant, owned]), + ListWalk => arena.alloc_slice_copy(&[owned, irrelevant, owned]), + ListWalkBackwards => arena.alloc_slice_copy(&[owned, irrelevant, owned]), ListSum => arena.alloc_slice_copy(&[borrowed]), // TODO when we have lists with capacity (if ever) @@ -661,7 +661,7 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] { arena.alloc_slice_copy(&[irrelevant]) } StrStartsWith | StrEndsWith => arena.alloc_slice_copy(&[owned, borrowed]), - StrFromInt => arena.alloc_slice_copy(&[irrelevant]), + StrFromInt | StrFromFloat => arena.alloc_slice_copy(&[irrelevant]), Hash => arena.alloc_slice_copy(&[borrowed, irrelevant]), DictSize => arena.alloc_slice_copy(&[borrowed]), DictEmpty => &[], @@ -670,5 +670,11 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] { DictContains => arena.alloc_slice_copy(&[borrowed, borrowed]), DictGetUnsafe => arena.alloc_slice_copy(&[borrowed, borrowed]), DictKeys | DictValues => arena.alloc_slice_copy(&[borrowed]), + DictUnion | DictDifference | DictIntersection => arena.alloc_slice_copy(&[owned, borrowed]), + + // borrow function argument so we don't have to worry about RC of the closure + DictWalk => arena.alloc_slice_copy(&[owned, borrowed, owned]), + + SetFromList => arena.alloc_slice_copy(&[owned]), } } diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index 6167aa55b7..0b94f7a489 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -6180,20 +6180,34 @@ fn call_by_name<'a>( None if assigned.module_id() != proc_name.module_id() => { add_needed_external(procs, env, original_fn_var, proc_name); - debug_assert_eq!( - arg_layouts.len(), - field_symbols.len(), - "scroll up a bit for background" - ); - - let call = self::Call { - call_type: CallType::ByName { - name: proc_name, - ret_layout: ret_layout.clone(), - full_layout: full_layout.clone(), - arg_layouts, - }, - arguments: field_symbols, + let call = if proc_name.module_id() == ModuleId::ATTR { + // the callable is one of the ATTR::ARG_n symbols + // we must call those by-pointer + self::Call { + call_type: CallType::ByPointer { + name: proc_name, + ret_layout: ret_layout.clone(), + full_layout: full_layout.clone(), + arg_layouts, + }, + arguments: field_symbols, + } + } else { + debug_assert_eq!( + arg_layouts.len(), + field_symbols.len(), + "scroll up a bit for background {:?}", + proc_name + ); + self::Call { + call_type: CallType::ByName { + name: proc_name, + ret_layout: ret_layout.clone(), + full_layout: full_layout.clone(), + arg_layouts, + }, + arguments: field_symbols, + } }; let result = diff --git a/compiler/mono/src/layout.rs b/compiler/mono/src/layout.rs index c35985c63b..668ba8f72f 100644 --- a/compiler/mono/src/layout.rs +++ b/compiler/mono/src/layout.rs @@ -1112,6 +1112,7 @@ fn layout_from_flat_type<'a>( other => Ok(other), } } + Symbol::SET_SET => dict_layout_from_key_value(env, args[0], Variable::EMPTY_RECORD), _ => { panic!("TODO layout_from_flat_type for {:?}", Apply(symbol, args)); } diff --git a/compiler/solve/tests/solve_expr.rs b/compiler/solve/tests/solve_expr.rs index d9cc0bccff..2013a74286 100644 --- a/compiler/solve/tests/solve_expr.rs +++ b/compiler/solve/tests/solve_expr.rs @@ -3257,7 +3257,7 @@ mod solve_expr { else Ok { position, cost: 0.0 } - Set.foldl model.openSet folder (Ok { position: boom {}, cost: 0.0 }) + Set.walk model.openSet folder (Ok { position: boom {}, cost: 0.0 }) |> Result.map (\x -> x.position) astar : Model position -> Result position [ KeyNotFound ]* diff --git a/examples/benchmarks/AStar.roc b/examples/benchmarks/AStar.roc new file mode 100644 index 0000000000..78f0022bf5 --- /dev/null +++ b/examples/benchmarks/AStar.roc @@ -0,0 +1,128 @@ +interface AStar exposes [ findPath, Model, initialModel ] imports [Quicksort] + +findPath = \costFn, moveFn, start, end -> + astar costFn moveFn end (initialModel start) + +Model position : + { + evaluated : Set position, + openSet : Set position, + costs : Dict position F64, + cameFrom : Dict position position + } + +initialModel : position -> Model position +initialModel = \start -> + { + evaluated : Set.empty, + openSet : Set.singleton start, + costs : Dict.singleton start 0, + cameFrom : Dict.empty + } + + +cheapestOpen : (position -> F64), Model position -> Result position {} +cheapestOpen = \costFn, model -> + model.openSet + |> Set.toList + |> List.keepOks (\position -> + when Dict.get model.costs position is + Err _ -> + Err {} + + Ok cost -> + Ok { cost: cost + costFn position, position } + ) + |> Quicksort.sortBy .cost + |> List.first + |> Result.map .position + |> Result.mapErr (\_ -> {}) + + +reconstructPath : Dict position position, position -> List position +reconstructPath = \cameFrom, goal -> + when Dict.get cameFrom goal is + Err _ -> + [] + + Ok next -> + List.append (reconstructPath cameFrom next) goal + +updateCost : position, position, Model position -> Model position +updateCost = \current, neighbor, model -> + when Dict.get model.costs neighbor is + Err _ -> + newCameFrom = + Dict.insert model.cameFrom neighbor current + + newCosts = + Dict.insert model.costs neighbor distanceTo + + distanceTo = + reconstructPath newCameFrom neighbor + |> List.len + |> Num.toFloat + + { model & + costs: newCosts, + cameFrom: newCameFrom + } + + Ok previousDistance -> + + newCameFrom = + Dict.insert model.cameFrom neighbor current + + newCosts = + Dict.insert model.costs neighbor distanceTo + + distanceTo = + reconstructPath newCameFrom neighbor + |> List.len + |> Num.toFloat + + newModel = + { model & + costs: newCosts, + cameFrom: newCameFrom + } + + + if distanceTo < previousDistance then + newModel + + else + model + +astar : (position, position -> F64), (position -> Set position), position, Model position -> Result (List position) {} +astar = \costFn, moveFn, goal, model -> + when cheapestOpen (\source -> costFn source goal) model is + Err {} -> + Err {} + + Ok current -> + if current == goal then + Ok (reconstructPath model.cameFrom goal) + + else + modelPopped = + { model & + openSet: Set.remove model.openSet current, + evaluated: Set.insert model.evaluated current, + } + + neighbors = + moveFn current + + newNeighbors = + Set.difference neighbors modelPopped.evaluated + + modelWithNeighbors = + { modelPopped & + openSet: Set.union modelPopped.openSet newNeighbors + } + + modelWithCosts = + Set.walk newNeighbors (\n, m -> updateCost current n m) modelWithNeighbors + + astar costFn moveFn goal modelWithCosts diff --git a/examples/benchmarks/AStarTests.roc b/examples/benchmarks/AStarTests.roc new file mode 100644 index 0000000000..84ee16b823 --- /dev/null +++ b/examples/benchmarks/AStarTests.roc @@ -0,0 +1,46 @@ +app "astar-tests" + packages { base: "platform" } + imports [base.Task, AStar] + provides [ main ] to base + +fromList : List a -> Set a +fromList = \list -> List.walk list (\x, a -> Set.insert a x) Set.empty + + +main : Task.Task {} [] +main = + Task.after Task.getInt \n -> + when n is + 1 -> + Task.putLine (showBool test1) + + _ -> + ns = Str.fromInt n + Task.putLine "No test \(ns)" + +showBool : Bool -> Str +showBool = \b -> + when b is + True -> "True" + False -> "False" + +test1 : Bool +test1 = + example1 == [3, 4] + +example1 : List I64 +example1 = + step : I64 -> Set I64 + step = \n -> + when n is + 1 -> fromList [ 2,3 ] + 2 -> fromList [4] + 3 -> fromList [4] + _ -> fromList [] + + cost : I64, I64 -> F64 + cost = \_, _ -> 1 + + when AStar.findPath cost step 1 4 is + Ok path -> path + Err _ -> []