Merge remote-tracking branch 'origin/trunk' into parse-closure

This commit is contained in:
Folkert 2021-02-26 13:04:54 +01:00
commit 647cbf4aaa
21 changed files with 975 additions and 258 deletions

View file

@ -231,14 +231,26 @@ mod cli_run {
#[serial(astar)] #[serial(astar)]
fn run_astar_optimized_1() { fn run_astar_optimized_1() {
check_output( check_output(
&example_file("benchmarks", "AStarTests.roc"), &example_file("benchmarks", "TestAStar.roc"),
"astar-tests", "test-astar",
&[], &[],
"True\n", "True\n",
false, false,
); );
} }
// Passes the `TestBase64.roc` benchmark example to `check_output` and expects
// "SGVsbG8gV29ybGQ=\n" as output, which is the Base64 encoding of "Hello World"
// followed by a newline.
// NOTE(review): the meaning of the trailing `true` flag is not visible here
// (the neighbouring astar test passes `false`) — confirm against `check_output`.
#[test]
#[serial(base64)]
fn base64() {
    check_output(
        &example_file("benchmarks", "TestBase64.roc"),
        "test-base64",
        &[],
        "SGVsbG8gV29ybGQ=\n",
        true,
    );
}
#[test] #[test]
#[serial(closure)] #[serial(closure)]
fn closure() { fn closure() {

View file

@ -67,7 +67,8 @@ comptime {
exportStrFn(str.strFromIntC, "from_int"); exportStrFn(str.strFromIntC, "from_int");
exportStrFn(str.strFromFloatC, "from_float"); exportStrFn(str.strFromFloatC, "from_float");
exportStrFn(str.strEqual, "equal"); exportStrFn(str.strEqual, "equal");
exportStrFn(str.validateUtf8Bytes, "validate_utf8_bytes"); exportStrFn(str.strToBytesC, "to_bytes");
exportStrFn(str.fromUtf8C, "from_utf8");
} }
// Export helpers - Must be run inside a comptime // Export helpers - Must be run inside a comptime

View file

@ -1,4 +1,5 @@
const utils = @import("utils.zig"); const utils = @import("utils.zig");
const RocList = @import("list.zig").RocList;
const std = @import("std"); const std = @import("std");
const mem = std.mem; const mem = std.mem;
const always_inline = std.builtin.CallOptions.Modifier.always_inline; const always_inline = std.builtin.CallOptions.Modifier.always_inline;
@ -14,6 +15,7 @@ const InPlace = packed enum(u8) {
Clone, Clone,
}; };
const SMALL_STR_MAX_LENGTH = small_string_size - 1;
const small_string_size = 2 * @sizeOf(usize); const small_string_size = 2 * @sizeOf(usize);
const blank_small_string: [16]u8 = init_blank_small_string(small_string_size); const blank_small_string: [16]u8 = init_blank_small_string(small_string_size);
@ -961,6 +963,91 @@ test "RocStr.joinWith: result is big" {
expect(roc_result.eq(result)); expect(roc_result.eq(result));
} }
// Str.toBytes
// C calling-convention wrapper around `strToBytes`. It always passes
// `std.heap.c_allocator` and forces the inner call to be inlined.
pub fn strToBytesC(arg: RocStr) callconv(.C) RocList {
    return @call(.{ .modifier = always_inline }, strToBytes, .{ std.heap.c_allocator, arg });
}
// Convert a RocStr into a RocList holding the string's bytes.
//
// - empty string: returns `RocList.empty()`.
// - small string: allocates a fresh refcounted buffer of `len()` bytes and
//   copies the bytes into it (presumably because small strings keep their
//   bytes inline rather than on the heap — confirm against RocStr).
// - any other string: the list reuses the string's `str_bytes` pointer as-is;
//   nothing is copied.
fn strToBytes(allocator: *Allocator, arg: RocStr) RocList {
    if (arg.isEmpty()) {
        return RocList.empty();
    }

    const length = arg.len();

    if (!arg.isSmallStr()) {
        // hand the existing allocation over to the list
        return RocList{ .bytes = arg.str_bytes, .length = length };
    }

    const copy = utils.allocateWithRefcount(allocator, @alignOf(usize), length);
    @memcpy(copy, arg.asU8ptr(), length);

    return RocList{ .bytes = copy, .length = length };
}
// Result record produced by `fromUtf8` and written out by `fromUtf8C`.
// NOTE(review): this is an `extern struct`, so the field order is part of its
// layout; presumably the compiler side depends on it — confirm before reordering.
const FromUtf8Result = extern struct {
    // index of the byte where validation failed; `fromUtf8` sets it to 0 on success
    byte_index: usize,
    // the resulting string on success; `RocStr.empty()` on failure
    string: RocStr,
    // true when the input bytes were valid UTF-8
    is_ok: bool,
    // which UTF-8 problem was found; on success `fromUtf8` stores a placeholder value
    problem_code: Utf8ByteProblem,
};
// C calling-convention wrapper around `fromUtf8`. The result is written
// through the `output` pointer instead of being returned by value. It always
// passes `std.heap.c_allocator` and forces the inner call to be inlined.
pub fn fromUtf8C(arg: RocList, output: *FromUtf8Result) callconv(.C) void {
    output.* = @call(.{ .modifier = always_inline }, fromUtf8, .{ std.heap.c_allocator, arg });
}
// Validate the bytes of `arg` as UTF-8 and turn them into a RocStr.
//
// Three outcomes:
// - invalid UTF-8: the problem and its byte index are looked up with
//   `errorToProblem`, the input list is decref'd, and an `is_ok = false`
//   result carrying an empty string is returned.
// - valid and longer than SMALL_STR_MAX_LENGTH: the list is passed through
//   `makeUnique` and its buffer becomes the string's buffer (no decref here).
// - valid and at most SMALL_STR_MAX_LENGTH bytes: a string is built with
//   `RocStr.init` from the bytes, then the input list is decref'd.
//
// On success `byte_index` is 0 and `problem_code` is a placeholder
// (`InvalidStartByte`); callers are expected to look at `is_ok` first.
fn fromUtf8(allocator: *Allocator, arg: RocList) FromUtf8Result {
    const bytes = @ptrCast([*]const u8, arg.bytes)[0..arg.length];

    if (!unicode.utf8ValidateSlice(bytes)) {
        const failure = errorToProblem(@ptrCast([*]u8, arg.bytes), arg.length);

        // consume the input list
        utils.decref(allocator, @alignOf(usize), arg.bytes, arg.len());

        return FromUtf8Result{
            .byte_index = failure.index,
            .string = RocStr.empty(),
            .is_ok = false,
            .problem_code = failure.problem,
        };
    }

    // The bytes are valid from here on; what remains is taking ownership of the input.
    if (arg.len() > SMALL_STR_MAX_LENGTH) {
        const unique_list = arg.makeUnique(allocator, @alignOf(usize), @sizeOf(u8));

        return FromUtf8Result{
            .byte_index = 0,
            .string = RocStr{ .str_bytes = unique_list.bytes, .str_len = unique_list.length },
            .is_ok = true,
            .problem_code = Utf8ByteProblem.InvalidStartByte,
        };
    }

    // Short input: build the string from the bytes first ...
    const small = RocStr.init(allocator, @ptrCast([*]u8, arg.bytes), arg.len());

    // ... and only then decrement the input list.
    utils.decref(allocator, @alignOf(usize), arg.bytes, arg.len());

    return FromUtf8Result{
        .byte_index = 0,
        .string = small,
        .is_ok = true,
        .problem_code = Utf8ByteProblem.InvalidStartByte,
    };
}
// Find the first invalid UTF-8 sequence in `bytes[0..length]`.
//
// Steps through the buffer one codepoint at a time using
// `numberOfNextCodepointBytes`. At the first decoding error it returns the
// index at which that codepoint starts, together with the matching
// `Utf8ByteProblem`.
//
// Must only be called on input already known to be invalid: walking off the
// end of the buffer without hitting an error reaches `unreachable`.
fn errorToProblem(bytes: [*]u8, length: usize) struct { index: usize, problem: Utf8ByteProblem } {
    var offset: usize = 0;

    while (offset < length) {
        const step = numberOfNextCodepointBytes(bytes, length, offset) catch |err| {
            const problem = switch (err) {
                error.UnexpectedEof => Utf8ByteProblem.UnexpectedEndOfSequence,
                error.Utf8InvalidStartByte => Utf8ByteProblem.InvalidStartByte,
                error.Utf8ExpectedContinuation => Utf8ByteProblem.ExpectedContinuation,
                error.Utf8OverlongEncoding => Utf8ByteProblem.OverlongEncoding,
                error.Utf8EncodesSurrogateHalf => Utf8ByteProblem.EncodesSurrogateHalf,
                error.Utf8CodepointTooLarge => Utf8ByteProblem.CodepointTooLarge,
            };
            return .{ .index = offset, .problem = problem };
        };

        offset += step;
    }

    unreachable;
}
pub fn isValidUnicode(ptr: [*]u8, len: usize) callconv(.C) bool { pub fn isValidUnicode(ptr: [*]u8, len: usize) callconv(.C) bool {
const bytes: []u8 = ptr[0..len]; const bytes: []u8 = ptr[0..len];
return @call(.{ .modifier = always_inline }, unicode.utf8ValidateSlice, .{bytes}); return @call(.{ .modifier = always_inline }, unicode.utf8ValidateSlice, .{bytes});
@ -998,174 +1085,170 @@ pub const Utf8ByteProblem = packed enum(u8) {
OverlongEncoding = 4, OverlongEncoding = 4,
UnexpectedEndOfSequence = 5, UnexpectedEndOfSequence = 5,
}; };
pub const ValidateUtf8BytesResult = extern struct {
is_ok: bool, byte_index: usize, problem_code: Utf8ByteProblem
};
const is_ok_utf8_byte_response = fn validateUtf8Bytes(bytes: [*]u8, length: usize) FromUtf8Result {
ValidateUtf8BytesResult{ .is_ok = true, .byte_index = 0, .problem_code = Utf8ByteProblem.UnexpectedEndOfSequence }; return fromUtf8(std.testing.allocator, RocList{ .bytes = bytes, .length = length });
inline fn toErrUtf8ByteResponse(byte_index: usize, problem_code: Utf8ByteProblem) ValidateUtf8BytesResult {
return ValidateUtf8BytesResult{ .is_ok = false, .byte_index = byte_index, .problem_code = problem_code };
} }
// Validate that an array of bytes is valid UTF-8, but if it fails catch & return the error & byte index fn validateUtf8BytesX(str: RocList) FromUtf8Result {
pub fn validateUtf8Bytes(ptr: [*]u8, len: usize) callconv(.C) ValidateUtf8BytesResult { return fromUtf8(std.testing.allocator, str);
var index: usize = 0;
while (index < len) {
const nextNumBytes = numberOfNextCodepointBytes(ptr, len, index) catch |err| {
return toErrUtf8ByteResponse(
index,
switch (err) {
error.UnexpectedEof => Utf8ByteProblem.UnexpectedEndOfSequence,
error.Utf8InvalidStartByte => Utf8ByteProblem.InvalidStartByte,
error.Utf8ExpectedContinuation => Utf8ByteProblem.ExpectedContinuation,
error.Utf8OverlongEncoding => Utf8ByteProblem.OverlongEncoding,
error.Utf8EncodesSurrogateHalf => Utf8ByteProblem.EncodesSurrogateHalf,
error.Utf8CodepointTooLarge => Utf8ByteProblem.CodepointTooLarge,
},
);
};
index += nextNumBytes;
}
return is_ok_utf8_byte_response;
} }
// Test helper: assert that a `fromUtf8` result reports success.
fn expectOk(result: FromUtf8Result) void {
    // `expectEqual` takes the expected value first and the actual value second;
    // keeping that order makes the failure message read correctly.
    expectEqual(true, result.is_ok);
}
// Test helper: build a RocList from raw bytes.
//
// Allocates a list for `length` one-byte elements with the testing allocator,
// copies `length` bytes from `bytes` into it, and sets the list's length.
fn sliceHelp(bytes: [*]const u8, length: usize) RocList {
    var list = RocList.allocate(testing.allocator, @alignOf(usize), length, @sizeOf(u8));

    // the helper treats a missing buffer as impossible here
    const destination = list.bytes orelse unreachable;
    @memcpy(destination, bytes, length);

    list.length = length;
    return list;
}
// Test helper: build the failed `FromUtf8Result` expected for a problem of
// kind `problem` found at byte `index`. The string field is always empty.
fn toErrUtf8ByteResponse(index: usize, problem: Utf8ByteProblem) FromUtf8Result {
    const failure = FromUtf8Result{
        .byte_index = index,
        .string = RocStr.empty(),
        .is_ok = false,
        .problem_code = problem,
    };
    return failure;
}
// NOTE on memory: the validate function consumes one RC token of the input. Since
// we freshly created the input (in `sliceHelp`), it has only one RC token, so the input list will be deallocated.
//
// If we tested with big strings, we'd have to deallocate the output string, but never the input list.
test "validateUtf8Bytes: ascii" { test "validateUtf8Bytes: ascii" {
const str_len = 3; const raw = "abc";
var str: [str_len]u8 = "abc".*; const ptr: [*]const u8 = @ptrCast([*]const u8, raw);
const str_ptr: [*]u8 = &str; const list = sliceHelp(ptr, raw.len);
expectEqual(is_ok_utf8_byte_response, validateUtf8Bytes(str_ptr, str_len)); expectOk(validateUtf8BytesX(list));
} }
test "validateUtf8Bytes: unicode œ" { test "validateUtf8Bytes: unicode œ" {
const str_len = 2; const raw = "œ";
var str: [str_len]u8 = "œ".*; const ptr: [*]const u8 = @ptrCast([*]const u8, raw);
const str_ptr: [*]u8 = &str; const list = sliceHelp(ptr, raw.len);
expectEqual(is_ok_utf8_byte_response, validateUtf8Bytes(str_ptr, str_len)); expectOk(validateUtf8BytesX(list));
} }
test "validateUtf8Bytes: unicode ∆" { test "validateUtf8Bytes: unicode ∆" {
const str_len = 3; const raw = "∆";
var str: [str_len]u8 = "∆".*; const ptr: [*]const u8 = @ptrCast([*]const u8, raw);
const str_ptr: [*]u8 = &str; const list = sliceHelp(ptr, raw.len);
expectEqual(is_ok_utf8_byte_response, validateUtf8Bytes(str_ptr, str_len)); expectOk(validateUtf8BytesX(list));
} }
test "validateUtf8Bytes: emoji" { test "validateUtf8Bytes: emoji" {
const str_len = 4; const raw = "💖";
var str: [str_len]u8 = "💖".*; const ptr: [*]const u8 = @ptrCast([*]const u8, raw);
const str_ptr: [*]u8 = &str; const list = sliceHelp(ptr, raw.len);
expectEqual(is_ok_utf8_byte_response, validateUtf8Bytes(str_ptr, str_len)); expectOk(validateUtf8BytesX(list));
} }
test "validateUtf8Bytes: unicode ∆ in middle of array" { test "validateUtf8Bytes: unicode ∆ in middle of array" {
const str_len = 9; const raw = "œb∆c¬";
var str: [str_len]u8 = "œb∆c¬".*; const ptr: [*]const u8 = @ptrCast([*]const u8, raw);
const str_ptr: [*]u8 = &str; const list = sliceHelp(ptr, raw.len);
expectEqual(is_ok_utf8_byte_response, validateUtf8Bytes(str_ptr, str_len)); expectOk(validateUtf8BytesX(list));
}
fn expectErr(list: RocList, index: usize, err: Utf8DecodeError, problem: Utf8ByteProblem) void {
const str_ptr = @ptrCast([*]u8, list.bytes);
const str_len = list.length;
expectError(err, numberOfNextCodepointBytes(str_ptr, str_len, index));
expectEqual(toErrUtf8ByteResponse(index, problem), validateUtf8Bytes(str_ptr, str_len));
} }
test "validateUtf8Bytes: invalid start byte" { test "validateUtf8Bytes: invalid start byte" {
// https://github.com/ziglang/zig/blob/0.7.x/lib/std/unicode.zig#L426 // https://github.com/ziglang/zig/blob/0.7.x/lib/std/unicode.zig#L426
const str_len = 4; const raw = "ab\x80c";
var str: [str_len]u8 = "ab\x80c".*; const ptr: [*]const u8 = @ptrCast([*]const u8, raw);
const str_ptr: [*]u8 = &str; const list = sliceHelp(ptr, raw.len);
expectError(error.Utf8InvalidStartByte, numberOfNextCodepointBytes(str_ptr, str_len, 2)); expectErr(list, 2, error.Utf8InvalidStartByte, Utf8ByteProblem.InvalidStartByte);
expectEqual(toErrUtf8ByteResponse(2, Utf8ByteProblem.InvalidStartByte), validateUtf8Bytes(str_ptr, str_len));
} }
test "validateUtf8Bytes: unexpected eof for 2 byte sequence" { test "validateUtf8Bytes: unexpected eof for 2 byte sequence" {
// https://github.com/ziglang/zig/blob/0.7.x/lib/std/unicode.zig#L426 // https://github.com/ziglang/zig/blob/0.7.x/lib/std/unicode.zig#L426
const str_len = 4; const raw = "abc\xc2";
var str: [str_len]u8 = "abc\xc2".*; const ptr: [*]const u8 = @ptrCast([*]const u8, raw);
const str_ptr: [*]u8 = &str; const list = sliceHelp(ptr, raw.len);
expectError(error.UnexpectedEof, numberOfNextCodepointBytes(str_ptr, str_len, 3)); expectErr(list, 3, error.UnexpectedEof, Utf8ByteProblem.UnexpectedEndOfSequence);
expectEqual(toErrUtf8ByteResponse(3, Utf8ByteProblem.UnexpectedEndOfSequence), validateUtf8Bytes(str_ptr, str_len));
} }
test "validateUtf8Bytes: expected continuation for 2 byte sequence" { test "validateUtf8Bytes: expected continuation for 2 byte sequence" {
// https://github.com/ziglang/zig/blob/0.7.x/lib/std/unicode.zig#L426 // https://github.com/ziglang/zig/blob/0.7.x/lib/std/unicode.zig#L426
const str_len = 5; const raw = "abc\xc2\x00";
var str: [str_len]u8 = "abc\xc2\x00".*; const ptr: [*]const u8 = @ptrCast([*]const u8, raw);
const str_ptr: [*]u8 = &str; const list = sliceHelp(ptr, raw.len);
expectError(error.Utf8ExpectedContinuation, numberOfNextCodepointBytes(str_ptr, str_len, 3)); expectErr(list, 3, error.Utf8ExpectedContinuation, Utf8ByteProblem.ExpectedContinuation);
expectEqual(toErrUtf8ByteResponse(3, Utf8ByteProblem.ExpectedContinuation), validateUtf8Bytes(str_ptr, str_len));
} }
test "validateUtf8Bytes: unexpected eof for 3 byte sequence" { test "validateUtf8Bytes: unexpected eof for 3 byte sequence" {
// https://github.com/ziglang/zig/blob/0.7.x/lib/std/unicode.zig#L430 // https://github.com/ziglang/zig/blob/0.7.x/lib/std/unicode.zig#L430
const str_len = 5; const raw = "abc\xe0\x00";
var str: [str_len]u8 = "abc\xe0\x00".*; const ptr: [*]const u8 = @ptrCast([*]const u8, raw);
const str_ptr: [*]u8 = &str; const list = sliceHelp(ptr, raw.len);
expectError(error.UnexpectedEof, numberOfNextCodepointBytes(str_ptr, str_len, 3)); expectErr(list, 3, error.UnexpectedEof, Utf8ByteProblem.UnexpectedEndOfSequence);
expectEqual(toErrUtf8ByteResponse(3, Utf8ByteProblem.UnexpectedEndOfSequence), validateUtf8Bytes(str_ptr, str_len));
} }
test "validateUtf8Bytes: expected continuation for 3 byte sequence" { test "validateUtf8Bytes: expected continuation for 3 byte sequence" {
// https://github.com/ziglang/zig/blob/0.7.x/lib/std/unicode.zig#L430 // https://github.com/ziglang/zig/blob/0.7.x/lib/std/unicode.zig#L430
const str_len = 6; const raw = "abc\xe0\xa0\xc0";
var str: [str_len]u8 = "abc\xe0\xa0\xc0".*; const ptr: [*]const u8 = @ptrCast([*]const u8, raw);
const str_ptr: [*]u8 = &str; const list = sliceHelp(ptr, raw.len);
expectError(error.Utf8ExpectedContinuation, numberOfNextCodepointBytes(str_ptr, str_len, 3)); expectErr(list, 3, error.Utf8ExpectedContinuation, Utf8ByteProblem.ExpectedContinuation);
expectEqual(toErrUtf8ByteResponse(3, Utf8ByteProblem.ExpectedContinuation), validateUtf8Bytes(str_ptr, str_len));
} }
test "validateUtf8Bytes: unexpected eof for 4 byte sequence" { test "validateUtf8Bytes: unexpected eof for 4 byte sequence" {
// https://github.com/ziglang/zig/blob/0.7.x/lib/std/unicode.zig#L437 // https://github.com/ziglang/zig/blob/0.7.x/lib/std/unicode.zig#L437
const str_len = 6; const raw = "abc\xf0\x90\x00";
var str: [str_len]u8 = "abc\xf0\x90\x00".*; const ptr: [*]const u8 = @ptrCast([*]const u8, raw);
const str_ptr: [*]u8 = &str; const list = sliceHelp(ptr, raw.len);
expectError(error.UnexpectedEof, numberOfNextCodepointBytes(str_ptr, str_len, 3)); expectErr(list, 3, error.UnexpectedEof, Utf8ByteProblem.UnexpectedEndOfSequence);
expectEqual(toErrUtf8ByteResponse(3, Utf8ByteProblem.UnexpectedEndOfSequence), validateUtf8Bytes(str_ptr, str_len));
} }
test "validateUtf8Bytes: expected continuation for 4 byte sequence" { test "validateUtf8Bytes: expected continuation for 4 byte sequence" {
// https://github.com/ziglang/zig/blob/0.7.x/lib/std/unicode.zig#L437 // https://github.com/ziglang/zig/blob/0.7.x/lib/std/unicode.zig#L437
const str_len = 7; const raw = "abc\xf0\x90\x80\x00";
var str: [str_len]u8 = "abc\xf0\x90\x80\x00".*; const ptr: [*]const u8 = @ptrCast([*]const u8, raw);
const str_ptr: [*]u8 = &str; const list = sliceHelp(ptr, raw.len);
expectError(error.Utf8ExpectedContinuation, numberOfNextCodepointBytes(str_ptr, str_len, 3)); expectErr(list, 3, error.Utf8ExpectedContinuation, Utf8ByteProblem.ExpectedContinuation);
expectEqual(toErrUtf8ByteResponse(3, Utf8ByteProblem.ExpectedContinuation), validateUtf8Bytes(str_ptr, str_len));
} }
test "validateUtf8Bytes: overlong" { test "validateUtf8Bytes: overlong" {
// https://github.com/ziglang/zig/blob/0.7.x/lib/std/unicode.zig#L451 // https://github.com/ziglang/zig/blob/0.7.x/lib/std/unicode.zig#L451
const str_len = 7; const raw = "abc\xf0\x80\x80\x80";
var str: [str_len]u8 = "abc\xf0\x80\x80\x80".*; const ptr: [*]const u8 = @ptrCast([*]const u8, raw);
const str_ptr: [*]u8 = &str; const list = sliceHelp(ptr, raw.len);
expectError(error.Utf8OverlongEncoding, numberOfNextCodepointBytes(str_ptr, str_len, 3)); expectErr(list, 3, error.Utf8OverlongEncoding, Utf8ByteProblem.OverlongEncoding);
expectEqual(toErrUtf8ByteResponse(3, Utf8ByteProblem.OverlongEncoding), validateUtf8Bytes(str_ptr, str_len));
} }
test "validateUtf8Bytes: codepoint out too large" { test "validateUtf8Bytes: codepoint out too large" {
// https://github.com/ziglang/zig/blob/0.7.x/lib/std/unicode.zig#L465 // https://github.com/ziglang/zig/blob/0.7.x/lib/std/unicode.zig#L465
const str_len = 7; const raw = "abc\xf4\x90\x80\x80";
var str: [str_len]u8 = "abc\xf4\x90\x80\x80".*; const ptr: [*]const u8 = @ptrCast([*]const u8, raw);
const str_ptr: [*]u8 = &str; const list = sliceHelp(ptr, raw.len);
expectError(error.Utf8CodepointTooLarge, numberOfNextCodepointBytes(str_ptr, str_len, 3)); expectErr(list, 3, error.Utf8CodepointTooLarge, Utf8ByteProblem.CodepointTooLarge);
expectEqual(toErrUtf8ByteResponse(3, Utf8ByteProblem.CodepointTooLarge), validateUtf8Bytes(str_ptr, str_len));
} }
test "validateUtf8Bytes: surrogate halves" { test "validateUtf8Bytes: surrogate halves" {
// https://github.com/ziglang/zig/blob/0.7.x/lib/std/unicode.zig#L468 // https://github.com/ziglang/zig/blob/0.7.x/lib/std/unicode.zig#L468
const str_len = 6; const raw = "abc\xed\xa0\x80";
var str: [str_len]u8 = "abc\xed\xa0\x80".*; const ptr: [*]const u8 = @ptrCast([*]const u8, raw);
const str_ptr: [*]u8 = &str; const list = sliceHelp(ptr, raw.len);
expectError(error.Utf8EncodesSurrogateHalf, numberOfNextCodepointBytes(str_ptr, str_len, 3)); expectErr(list, 3, error.Utf8EncodesSurrogateHalf, Utf8ByteProblem.EncodesSurrogateHalf);
expectEqual(toErrUtf8ByteResponse(3, Utf8ByteProblem.EncodesSurrogateHalf), validateUtf8Bytes(str_ptr, str_len));
} }

View file

@ -41,7 +41,8 @@ pub const STR_NUMBER_OF_BYTES: &str = "roc_builtins.str.number_of_bytes";
pub const STR_FROM_INT: &str = "roc_builtins.str.from_int"; pub const STR_FROM_INT: &str = "roc_builtins.str.from_int";
pub const STR_FROM_FLOAT: &str = "roc_builtins.str.from_float"; pub const STR_FROM_FLOAT: &str = "roc_builtins.str.from_float";
pub const STR_EQUAL: &str = "roc_builtins.str.equal"; pub const STR_EQUAL: &str = "roc_builtins.str.equal";
pub const STR_VALIDATE_UTF_BYTES: &str = "roc_builtins.str.validate_utf8_bytes"; pub const STR_TO_BYTES: &str = "roc_builtins.str.to_bytes";
pub const STR_FROM_UTF8: &str = "roc_builtins.str.from_utf8";
pub const DICT_HASH: &str = "roc_builtins.dict.hash"; pub const DICT_HASH: &str = "roc_builtins.dict.hash";
pub const DICT_HASH_STR: &str = "roc_builtins.dict.hash_str"; pub const DICT_HASH_STR: &str = "roc_builtins.dict.hash_str";

View file

@ -324,6 +324,48 @@ pub fn types() -> MutMap<Symbol, (SolvedType, Region)> {
), ),
); );
// bitwiseOr : Int a, Int a -> Int a
add_type(
Symbol::NUM_BITWISE_OR,
top_level_function(
vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))],
Box::new(int_type(flex(TVAR1))),
),
);
// shiftLeftBy : Int a, Int a -> Int a
add_type(
Symbol::NUM_SHIFT_LEFT,
top_level_function(
vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))],
Box::new(int_type(flex(TVAR1))),
),
);
// shiftRightBy : Int a, Int a -> Int a
add_type(
Symbol::NUM_SHIFT_RIGHT,
top_level_function(
vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))],
Box::new(int_type(flex(TVAR1))),
),
);
// shiftRightZfBy : Int a, Int a -> Int a
add_type(
Symbol::NUM_SHIFT_RIGHT_ZERO_FILL,
top_level_function(
vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))],
Box::new(int_type(flex(TVAR1))),
),
);
// intCast : Int a -> Int b
add_type(
Symbol::NUM_INT_CAST,
top_level_function(vec![int_type(flex(TVAR1))], Box::new(int_type(flex(TVAR2)))),
);
// rem : Int a, Int a -> Result (Int a) [ DivByZero ]* // rem : Int a, Int a -> Result (Int a) [ DivByZero ]*
add_type( add_type(
Symbol::NUM_REM, Symbol::NUM_REM,
@ -581,6 +623,12 @@ pub fn types() -> MutMap<Symbol, (SolvedType, Region)> {
), ),
); );
// toBytes : Str -> List U8
add_type(
Symbol::STR_TO_BYTES,
top_level_function(vec![str_type()], Box::new(list_type(u8_type()))),
);
// fromFloat : Float a -> Str // fromFloat : Float a -> Str
add_type( add_type(
Symbol::STR_FROM_FLOAT, Symbol::STR_FROM_FLOAT,

View file

@ -62,6 +62,7 @@ pub fn builtin_defs_map(symbol: Symbol, var_store: &mut VarStore) -> Option<Def>
STR_COUNT_GRAPHEMES => str_count_graphemes, STR_COUNT_GRAPHEMES => str_count_graphemes,
STR_FROM_INT => str_from_int, STR_FROM_INT => str_from_int,
STR_FROM_UTF8 => str_from_utf8, STR_FROM_UTF8 => str_from_utf8,
STR_TO_BYTES => str_to_bytes,
STR_FROM_FLOAT=> str_from_float, STR_FROM_FLOAT=> str_from_float,
LIST_LEN => list_len, LIST_LEN => list_len,
LIST_GET => list_get, LIST_GET => list_get,
@ -152,6 +153,11 @@ pub fn builtin_defs_map(symbol: Symbol, var_store: &mut VarStore) -> Option<Def>
NUM_MIN_INT => num_min_int, NUM_MIN_INT => num_min_int,
NUM_BITWISE_AND => num_bitwise_and, NUM_BITWISE_AND => num_bitwise_and,
NUM_BITWISE_XOR => num_bitwise_xor, NUM_BITWISE_XOR => num_bitwise_xor,
NUM_BITWISE_OR => num_bitwise_or,
NUM_SHIFT_LEFT=> num_shift_left_by,
NUM_SHIFT_RIGHT => num_shift_right_by,
NUM_SHIFT_RIGHT_ZERO_FILL => num_shift_right_zf_by,
NUM_INT_CAST=> num_int_cast,
RESULT_MAP => result_map, RESULT_MAP => result_map,
RESULT_MAP_ERR => result_map_err, RESULT_MAP_ERR => result_map_err,
RESULT_WITH_DEFAULT => result_with_default, RESULT_WITH_DEFAULT => result_with_default,
@ -191,6 +197,7 @@ pub fn builtin_defs(var_store: &mut VarStore) -> MutMap<Symbol, Def> {
Symbol::STR_COUNT_GRAPHEMES => str_count_graphemes, Symbol::STR_COUNT_GRAPHEMES => str_count_graphemes,
Symbol::STR_FROM_INT => str_from_int, Symbol::STR_FROM_INT => str_from_int,
Symbol::STR_FROM_UTF8 => str_from_utf8, Symbol::STR_FROM_UTF8 => str_from_utf8,
Symbol::STR_TO_BYTES => str_to_bytes,
Symbol::STR_FROM_FLOAT=> str_from_float, Symbol::STR_FROM_FLOAT=> str_from_float,
Symbol::LIST_LEN => list_len, Symbol::LIST_LEN => list_len,
Symbol::LIST_GET => list_get, Symbol::LIST_GET => list_get,
@ -275,6 +282,13 @@ pub fn builtin_defs(var_store: &mut VarStore) -> MutMap<Symbol, Def> {
Symbol::NUM_ASIN => num_asin, Symbol::NUM_ASIN => num_asin,
Symbol::NUM_MAX_INT => num_max_int, Symbol::NUM_MAX_INT => num_max_int,
Symbol::NUM_MIN_INT => num_min_int, Symbol::NUM_MIN_INT => num_min_int,
Symbol::NUM_BITWISE_AND => num_bitwise_and,
Symbol::NUM_BITWISE_XOR => num_bitwise_xor,
Symbol::NUM_BITWISE_OR => num_bitwise_or,
Symbol::NUM_SHIFT_LEFT => num_shift_left_by,
Symbol::NUM_SHIFT_RIGHT => num_shift_right_by,
Symbol::NUM_SHIFT_RIGHT_ZERO_FILL => num_shift_right_zf_by,
Symbol::NUM_INT_CAST=> num_int_cast,
Symbol::RESULT_MAP => result_map, Symbol::RESULT_MAP => result_map,
Symbol::RESULT_MAP_ERR => result_map_err, Symbol::RESULT_MAP_ERR => result_map_err,
Symbol::RESULT_WITH_DEFAULT => result_with_default, Symbol::RESULT_WITH_DEFAULT => result_with_default,
@ -1301,6 +1315,31 @@ fn num_bitwise_xor(symbol: Symbol, var_store: &mut VarStore) -> Def {
num_binop(symbol, var_store, LowLevel::NumBitwiseXor) num_binop(symbol, var_store, LowLevel::NumBitwiseXor)
} }
/// Num.bitwiseOr : Int a, Int a -> Int a
///
/// Builds the definition through `num_binop` with `LowLevel::NumBitwiseOr`.
fn num_bitwise_or(symbol: Symbol, var_store: &mut VarStore) -> Def {
    num_binop(symbol, var_store, LowLevel::NumBitwiseOr)
}
/// Num.shiftLeftBy : Int a, Int a -> Int a
///
/// Builds a two-argument definition via `lowlevel_2` with `LowLevel::NumShiftLeftBy`.
/// NOTE(review): the signature above mirrors the type registered for
/// `Symbol::NUM_SHIFT_LEFT` in the builtin type table; this comment previously
/// said `Nat, Int a -> Int a` — confirm which one is intended.
fn num_shift_left_by(symbol: Symbol, var_store: &mut VarStore) -> Def {
    lowlevel_2(symbol, LowLevel::NumShiftLeftBy, var_store)
}
/// Num.shiftRightBy : Int a, Int a -> Int a
///
/// Builds a two-argument definition via `lowlevel_2` with `LowLevel::NumShiftRightBy`.
/// NOTE(review): the signature above mirrors the type registered for
/// `Symbol::NUM_SHIFT_RIGHT` in the builtin type table; this comment previously
/// said `Nat, Int a -> Int a` — confirm which one is intended.
fn num_shift_right_by(symbol: Symbol, var_store: &mut VarStore) -> Def {
    lowlevel_2(symbol, LowLevel::NumShiftRightBy, var_store)
}
/// Num.shiftRightZfBy : Int a, Int a -> Int a
///
/// Builds a two-argument definition via `lowlevel_2` with `LowLevel::NumShiftRightZfBy`.
/// NOTE(review): the signature above mirrors the type registered for
/// `Symbol::NUM_SHIFT_RIGHT_ZERO_FILL` in the builtin type table; this comment
/// previously said `Nat, Int a -> Int a` — confirm which one is intended.
fn num_shift_right_zf_by(symbol: Symbol, var_store: &mut VarStore) -> Def {
    lowlevel_2(symbol, LowLevel::NumShiftRightZfBy, var_store)
}
/// Num.intCast : Int a -> Int b
///
/// Builds a one-argument definition via `lowlevel_1` with `LowLevel::NumIntCast`.
fn num_int_cast(symbol: Symbol, var_store: &mut VarStore) -> Def {
    lowlevel_1(symbol, LowLevel::NumIntCast, var_store)
}
/// List.isEmpty : List * -> Bool /// List.isEmpty : List * -> Bool
fn list_is_empty(symbol: Symbol, var_store: &mut VarStore) -> Def { fn list_is_empty(symbol: Symbol, var_store: &mut VarStore) -> Def {
let list_var = var_store.fresh(); let list_var = var_store.fresh();
@ -1559,7 +1598,7 @@ fn str_from_utf8(symbol: Symbol, var_store: &mut VarStore) -> Def {
Access { Access {
record_var, record_var,
ext_var: var_store.fresh(), ext_var: var_store.fresh(),
field: "isOk".into(), field: "c_isOk".into(),
field_var: var_store.fresh(), field_var: var_store.fresh(),
loc_expr: Box::new(no_region(Var(Symbol::ARG_2))), loc_expr: Box::new(no_region(Var(Symbol::ARG_2))),
}, },
@ -1571,7 +1610,7 @@ fn str_from_utf8(symbol: Symbol, var_store: &mut VarStore) -> Def {
vec![Access { vec![Access {
record_var, record_var,
ext_var: var_store.fresh(), ext_var: var_store.fresh(),
field: "str".into(), field: "b_str".into(),
field_var: var_store.fresh(), field_var: var_store.fresh(),
loc_expr: Box::new(no_region(Var(Symbol::ARG_2))), loc_expr: Box::new(no_region(Var(Symbol::ARG_2))),
}], }],
@ -1588,14 +1627,14 @@ fn str_from_utf8(symbol: Symbol, var_store: &mut VarStore) -> Def {
Access { Access {
record_var, record_var,
ext_var: var_store.fresh(), ext_var: var_store.fresh(),
field: "problem".into(), field: "d_problem".into(),
field_var: var_store.fresh(), field_var: var_store.fresh(),
loc_expr: Box::new(no_region(Var(Symbol::ARG_2))), loc_expr: Box::new(no_region(Var(Symbol::ARG_2))),
}, },
Access { Access {
record_var, record_var,
ext_var: var_store.fresh(), ext_var: var_store.fresh(),
field: "byteIndex".into(), field: "a_byteIndex".into(),
field_var: var_store.fresh(), field_var: var_store.fresh(),
loc_expr: Box::new(no_region(Var(Symbol::ARG_2))), loc_expr: Box::new(no_region(Var(Symbol::ARG_2))),
}, },
@ -1618,6 +1657,11 @@ fn str_from_utf8(symbol: Symbol, var_store: &mut VarStore) -> Def {
) )
} }
/// Str.toBytes : Str -> List U8
///
/// Builds a one-argument definition via `lowlevel_1` with `LowLevel::StrToBytes`.
fn str_to_bytes(symbol: Symbol, var_store: &mut VarStore) -> Def {
    lowlevel_1(symbol, LowLevel::StrToBytes, var_store)
}
/// Str.fromFloat : Float * -> Str /// Str.fromFloat : Float * -> Str
fn str_from_float(symbol: Symbol, var_store: &mut VarStore) -> Def { fn str_from_float(symbol: Symbol, var_store: &mut VarStore) -> Def {
let float_var = var_store.fresh(); let float_var = var_store.fresh();

View file

@ -12,7 +12,7 @@ use crate::llvm::build_list::{
}; };
use crate::llvm::build_str::{ use crate::llvm::build_str::{
str_concat, str_count_graphemes, str_ends_with, str_from_float, str_from_int, str_from_utf8, str_concat, str_count_graphemes, str_ends_with, str_from_float, str_from_int, str_from_utf8,
str_join_with, str_number_of_bytes, str_split, str_starts_with, CHAR_LAYOUT, str_join_with, str_number_of_bytes, str_split, str_starts_with, str_to_bytes, CHAR_LAYOUT,
}; };
use crate::llvm::compare::{generic_eq, generic_neq}; use crate::llvm::compare::{generic_eq, generic_neq};
use crate::llvm::convert::{ use crate::llvm::convert::{
@ -296,8 +296,10 @@ fn add_intrinsics<'ctx>(ctx: &'ctx Context, module: &Module<'ctx>) {
let void_type = ctx.void_type(); let void_type = ctx.void_type();
let i1_type = ctx.bool_type(); let i1_type = ctx.bool_type();
let f64_type = ctx.f64_type(); let f64_type = ctx.f64_type();
let i128_type = ctx.i128_type();
let i64_type = ctx.i64_type(); let i64_type = ctx.i64_type();
let i32_type = ctx.i32_type(); let i32_type = ctx.i32_type();
let i16_type = ctx.i16_type();
let i8_type = ctx.i8_type(); let i8_type = ctx.i8_type();
let i8_ptr_type = i8_type.ptr_type(AddressSpace::Generic); let i8_ptr_type = i8_type.ptr_type(AddressSpace::Generic);
@ -377,18 +379,72 @@ fn add_intrinsics<'ctx>(ctx: &'ctx Context, module: &Module<'ctx>) {
f64_type.fn_type(&[f64_type.into()], false), f64_type.fn_type(&[f64_type.into()], false),
); );
// add with overflow
add_intrinsic(module, LLVM_SADD_WITH_OVERFLOW_I8, {
let fields = [i8_type.into(), i1_type.into()];
ctx.struct_type(&fields, false)
.fn_type(&[i8_type.into(), i8_type.into()], false)
});
add_intrinsic(module, LLVM_SADD_WITH_OVERFLOW_I16, {
let fields = [i16_type.into(), i1_type.into()];
ctx.struct_type(&fields, false)
.fn_type(&[i16_type.into(), i16_type.into()], false)
});
add_intrinsic(module, LLVM_SADD_WITH_OVERFLOW_I32, {
let fields = [i32_type.into(), i1_type.into()];
ctx.struct_type(&fields, false)
.fn_type(&[i32_type.into(), i32_type.into()], false)
});
add_intrinsic(module, LLVM_SADD_WITH_OVERFLOW_I64, { add_intrinsic(module, LLVM_SADD_WITH_OVERFLOW_I64, {
let fields = [i64_type.into(), i1_type.into()]; let fields = [i64_type.into(), i1_type.into()];
ctx.struct_type(&fields, false) ctx.struct_type(&fields, false)
.fn_type(&[i64_type.into(), i64_type.into()], false) .fn_type(&[i64_type.into(), i64_type.into()], false)
}); });
add_intrinsic(module, LLVM_SADD_WITH_OVERFLOW_I128, {
let fields = [i128_type.into(), i1_type.into()];
ctx.struct_type(&fields, false)
.fn_type(&[i128_type.into(), i128_type.into()], false)
});
// sub with overflow
add_intrinsic(module, LLVM_SSUB_WITH_OVERFLOW_I8, {
let fields = [i8_type.into(), i1_type.into()];
ctx.struct_type(&fields, false)
.fn_type(&[i8_type.into(), i8_type.into()], false)
});
add_intrinsic(module, LLVM_SSUB_WITH_OVERFLOW_I16, {
let fields = [i16_type.into(), i1_type.into()];
ctx.struct_type(&fields, false)
.fn_type(&[i16_type.into(), i16_type.into()], false)
});
add_intrinsic(module, LLVM_SSUB_WITH_OVERFLOW_I32, {
let fields = [i32_type.into(), i1_type.into()];
ctx.struct_type(&fields, false)
.fn_type(&[i32_type.into(), i32_type.into()], false)
});
add_intrinsic(module, LLVM_SSUB_WITH_OVERFLOW_I64, { add_intrinsic(module, LLVM_SSUB_WITH_OVERFLOW_I64, {
let fields = [i64_type.into(), i1_type.into()]; let fields = [i64_type.into(), i1_type.into()];
ctx.struct_type(&fields, false) ctx.struct_type(&fields, false)
.fn_type(&[i64_type.into(), i64_type.into()], false) .fn_type(&[i64_type.into(), i64_type.into()], false)
}); });
add_intrinsic(module, LLVM_SSUB_WITH_OVERFLOW_I128, {
let fields = [i128_type.into(), i1_type.into()];
ctx.struct_type(&fields, false)
.fn_type(&[i128_type.into(), i128_type.into()], false)
});
// mul with overflow
add_intrinsic(module, LLVM_SMUL_WITH_OVERFLOW_I64, { add_intrinsic(module, LLVM_SMUL_WITH_OVERFLOW_I64, {
let fields = [i64_type.into(), i1_type.into()]; let fields = [i64_type.into(), i1_type.into()];
ctx.struct_type(&fields, false) ctx.struct_type(&fields, false)
@ -406,8 +462,19 @@ static LLVM_COS_F64: &str = "llvm.cos.f64";
static LLVM_POW_F64: &str = "llvm.pow.f64"; static LLVM_POW_F64: &str = "llvm.pow.f64";
static LLVM_CEILING_F64: &str = "llvm.ceil.f64"; static LLVM_CEILING_F64: &str = "llvm.ceil.f64";
static LLVM_FLOOR_F64: &str = "llvm.floor.f64"; static LLVM_FLOOR_F64: &str = "llvm.floor.f64";
pub static LLVM_SADD_WITH_OVERFLOW_I8: &str = "llvm.sadd.with.overflow.i8";
pub static LLVM_SADD_WITH_OVERFLOW_I16: &str = "llvm.sadd.with.overflow.i16";
pub static LLVM_SADD_WITH_OVERFLOW_I32: &str = "llvm.sadd.with.overflow.i32";
pub static LLVM_SADD_WITH_OVERFLOW_I64: &str = "llvm.sadd.with.overflow.i64"; pub static LLVM_SADD_WITH_OVERFLOW_I64: &str = "llvm.sadd.with.overflow.i64";
pub static LLVM_SADD_WITH_OVERFLOW_I128: &str = "llvm.sadd.with.overflow.i128";
pub static LLVM_SSUB_WITH_OVERFLOW_I8: &str = "llvm.ssub.with.overflow.i8";
pub static LLVM_SSUB_WITH_OVERFLOW_I16: &str = "llvm.ssub.with.overflow.i16";
pub static LLVM_SSUB_WITH_OVERFLOW_I32: &str = "llvm.ssub.with.overflow.i32";
pub static LLVM_SSUB_WITH_OVERFLOW_I64: &str = "llvm.ssub.with.overflow.i64"; pub static LLVM_SSUB_WITH_OVERFLOW_I64: &str = "llvm.ssub.with.overflow.i64";
pub static LLVM_SSUB_WITH_OVERFLOW_I128: &str = "llvm.ssub.with.overflow.i128";
pub static LLVM_SMUL_WITH_OVERFLOW_I64: &str = "llvm.smul.with.overflow.i64"; pub static LLVM_SMUL_WITH_OVERFLOW_I64: &str = "llvm.smul.with.overflow.i64";
fn add_intrinsic<'ctx>( fn add_intrinsic<'ctx>(
@ -3544,13 +3611,23 @@ fn run_low_level<'a, 'ctx, 'env>(
str_from_float(env, scope, args[0]) str_from_float(env, scope, args[0])
} }
StrFromUtf8 => { StrFromUtf8 => {
// Str.fromInt : Int -> Str // Str.fromUtf8 : List U8 -> Result Str Utf8Problem
debug_assert_eq!(args.len(), 1); debug_assert_eq!(args.len(), 1);
let original_wrapper = load_symbol(scope, &args[0]).into_struct_value(); let original_wrapper = load_symbol(scope, &args[0]).into_struct_value();
str_from_utf8(env, parent, original_wrapper) str_from_utf8(env, parent, original_wrapper)
} }
StrToBytes => {
// Str.toBytes : Str -> List U8
debug_assert_eq!(args.len(), 1);
// this is an identity conversion
// we just implement it here to subvert the type system
let string = load_symbol(scope, &args[0]);
str_to_bytes(env, string.into_struct_value())
}
StrSplit => { StrSplit => {
// Str.split : Str, Str -> List Str // Str.split : Str, Str -> List Str
debug_assert_eq!(args.len(), 2); debug_assert_eq!(args.len(), 2);
@ -3951,7 +4028,7 @@ fn run_low_level<'a, 'ctx, 'env>(
build_num_binop(env, parent, lhs_arg, lhs_layout, rhs_arg, rhs_layout, op) build_num_binop(env, parent, lhs_arg, lhs_layout, rhs_arg, rhs_layout, op)
} }
NumBitwiseAnd | NumBitwiseXor => { NumBitwiseAnd | NumBitwiseOr | NumBitwiseXor => {
debug_assert_eq!(args.len(), 2); debug_assert_eq!(args.len(), 2);
let (lhs_arg, lhs_layout) = load_symbol_and_layout(scope, &args[0]); let (lhs_arg, lhs_layout) = load_symbol_and_layout(scope, &args[0]);
@ -3967,6 +4044,32 @@ fn run_low_level<'a, 'ctx, 'env>(
op, op,
) )
} }
NumShiftLeftBy | NumShiftRightBy | NumShiftRightZfBy => {
debug_assert_eq!(args.len(), 2);
let (lhs_arg, lhs_layout) = load_symbol_and_layout(scope, &args[0]);
let (rhs_arg, rhs_layout) = load_symbol_and_layout(scope, &args[1]);
build_int_binop(
env,
parent,
lhs_arg.into_int_value(),
lhs_layout,
rhs_arg.into_int_value(),
rhs_layout,
op,
)
}
NumIntCast => {
debug_assert_eq!(args.len(), 1);
let arg = load_symbol(scope, &args[0]).into_int_value();
let to = basic_type_from_layout(env.arena, env.context, layout, env.ptr_bytes)
.into_int_type();
env.builder.build_int_cast(arg, to, "inc_cast").into()
}
Eq => { Eq => {
debug_assert_eq!(args.len(), 2); debug_assert_eq!(args.len(), 2);
@ -4480,7 +4583,7 @@ fn build_int_binop<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>, env: &Env<'a, 'ctx, 'env>,
parent: FunctionValue<'ctx>, parent: FunctionValue<'ctx>,
lhs: IntValue<'ctx>, lhs: IntValue<'ctx>,
_lhs_layout: &Layout<'a>, lhs_layout: &Layout<'a>,
rhs: IntValue<'ctx>, rhs: IntValue<'ctx>,
_rhs_layout: &Layout<'a>, _rhs_layout: &Layout<'a>,
op: LowLevel, op: LowLevel,
@ -4493,8 +4596,23 @@ fn build_int_binop<'a, 'ctx, 'env>(
match op { match op {
NumAdd => { NumAdd => {
let context = env.context; let context = env.context;
let intrinsic = match lhs_layout {
Layout::Builtin(Builtin::Int8) => LLVM_SADD_WITH_OVERFLOW_I8,
Layout::Builtin(Builtin::Int16) => LLVM_SADD_WITH_OVERFLOW_I16,
Layout::Builtin(Builtin::Int32) => LLVM_SADD_WITH_OVERFLOW_I32,
Layout::Builtin(Builtin::Int64) => LLVM_SADD_WITH_OVERFLOW_I64,
Layout::Builtin(Builtin::Int128) => LLVM_SADD_WITH_OVERFLOW_I128,
Layout::Builtin(Builtin::Usize) => match env.ptr_bytes {
4 => LLVM_SADD_WITH_OVERFLOW_I32,
8 => LLVM_SADD_WITH_OVERFLOW_I64,
other => panic!("invalid ptr_bytes {}", other),
},
_ => unreachable!(),
};
let result = env let result = env
.call_intrinsic(LLVM_SADD_WITH_OVERFLOW_I64, &[lhs.into(), rhs.into()]) .call_intrinsic(intrinsic, &[lhs.into(), rhs.into()])
.into_struct_value(); .into_struct_value();
let add_result = bd.build_extract_value(result, 0, "add_result").unwrap(); let add_result = bd.build_extract_value(result, 0, "add_result").unwrap();
@ -4524,8 +4642,23 @@ fn build_int_binop<'a, 'ctx, 'env>(
NumAddChecked => env.call_intrinsic(LLVM_SADD_WITH_OVERFLOW_I64, &[lhs.into(), rhs.into()]), NumAddChecked => env.call_intrinsic(LLVM_SADD_WITH_OVERFLOW_I64, &[lhs.into(), rhs.into()]),
NumSub => { NumSub => {
let context = env.context; let context = env.context;
let intrinsic = match lhs_layout {
Layout::Builtin(Builtin::Int8) => LLVM_SSUB_WITH_OVERFLOW_I8,
Layout::Builtin(Builtin::Int16) => LLVM_SSUB_WITH_OVERFLOW_I16,
Layout::Builtin(Builtin::Int32) => LLVM_SSUB_WITH_OVERFLOW_I32,
Layout::Builtin(Builtin::Int64) => LLVM_SSUB_WITH_OVERFLOW_I64,
Layout::Builtin(Builtin::Int128) => LLVM_SSUB_WITH_OVERFLOW_I128,
Layout::Builtin(Builtin::Usize) => match env.ptr_bytes {
4 => LLVM_SSUB_WITH_OVERFLOW_I32,
8 => LLVM_SSUB_WITH_OVERFLOW_I64,
other => panic!("invalid ptr_bytes {}", other),
},
_ => unreachable!("invalid layout {:?}", lhs_layout),
};
let result = env let result = env
.call_intrinsic(LLVM_SSUB_WITH_OVERFLOW_I64, &[lhs.into(), rhs.into()]) .call_intrinsic(intrinsic, &[lhs.into(), rhs.into()])
.into_struct_value(); .into_struct_value();
let sub_result = bd.build_extract_value(result, 0, "sub_result").unwrap(); let sub_result = bd.build_extract_value(result, 0, "sub_result").unwrap();
@ -4593,6 +4726,24 @@ fn build_int_binop<'a, 'ctx, 'env>(
NumPowInt => call_bitcode_fn(env, &[lhs.into(), rhs.into()], &bitcode::NUM_POW_INT), NumPowInt => call_bitcode_fn(env, &[lhs.into(), rhs.into()], &bitcode::NUM_POW_INT),
NumBitwiseAnd => bd.build_and(lhs, rhs, "int_bitwise_and").into(), NumBitwiseAnd => bd.build_and(lhs, rhs, "int_bitwise_and").into(),
NumBitwiseXor => bd.build_xor(lhs, rhs, "int_bitwise_xor").into(), NumBitwiseXor => bd.build_xor(lhs, rhs, "int_bitwise_xor").into(),
NumBitwiseOr => bd.build_or(lhs, rhs, "int_bitwise_or").into(),
NumShiftLeftBy => {
// NOTE arguments are flipped;
// we write `assert_eq!(0b0000_0001 << 0, 0b0000_0001);`
// as `Num.shiftLeftBy 0 0b0000_0001`
bd.build_left_shift(rhs, lhs, "int_shift_left").into()
}
NumShiftRightBy => {
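// NOTE arguments are flipped; the shift amount is the first argument.
// TODO(review): inkwell's `build_right_shift` takes `sign_extend` as its third
// argument, so `false` requests a logical (zero-filling) shift. The tests describe
// `Num.shiftRightBy` as the sign-extending shift -- confirm this flag is not inverted.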
bd.build_right_shift(rhs, lhs, false, "int_shift_right")
.into()
}
NumShiftRightZfBy => {
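// NOTE arguments are flipped; the shift amount is the first argument.
// TODO(review): inkwell's `build_right_shift` takes `sign_extend` as its third
// argument, so `true` requests an arithmetic (sign-extending) shift. The test for the
// "Zf" variant describes it as the logical shift -- confirm this flag is not inverted.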
bd.build_right_shift(rhs, lhs, true, "int_shift_right_zf")
.into()
}
_ => { _ => {
unreachable!("Unrecognized int binary operation: {:?}", op); unreachable!("Unrecognized int binary operation: {:?}", op);
} }

View file

@ -1,13 +1,11 @@
use crate::llvm::bitcode::{call_bitcode_fn, call_void_bitcode_fn}; use crate::llvm::bitcode::{call_bitcode_fn, call_void_bitcode_fn};
use crate::llvm::build::{complex_bitcast, Env, InPlace, Scope}; use crate::llvm::build::{complex_bitcast, Env, InPlace, Scope};
use crate::llvm::build_list::{ use crate::llvm::build_list::{allocate_list, store_list};
allocate_list, build_basic_phi2, empty_polymorphic_list, list_len, load_list_ptr, store_list, use crate::llvm::convert::collection;
};
use crate::llvm::convert::{collection, get_ptr_type};
use inkwell::builder::Builder; use inkwell::builder::Builder;
use inkwell::types::{BasicTypeEnum, StructType}; use inkwell::types::BasicTypeEnum;
use inkwell::values::{BasicValueEnum, FunctionValue, IntValue, PointerValue, StructValue}; use inkwell::values::{BasicValueEnum, FunctionValue, IntValue, PointerValue, StructValue};
use inkwell::{AddressSpace, IntPredicate}; use inkwell::AddressSpace;
use roc_builtins::bitcode; use roc_builtins::bitcode;
use roc_module::symbol::Symbol; use roc_module::symbol::Symbol;
use roc_mono::layout::{Builtin, Layout}; use roc_mono::layout::{Builtin, Layout};
@ -275,46 +273,53 @@ pub fn str_from_int<'a, 'ctx, 'env>(
zig_str_to_struct(env, zig_result).into() zig_str_to_struct(env, zig_result).into()
} }
/// Str.toBytes : Str -> List U8
pub fn str_to_bytes<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
original_wrapper: StructValue<'ctx>,
) -> BasicValueEnum<'ctx> {
let string = complex_bitcast(
env.builder,
original_wrapper.into(),
env.context.i128_type().into(),
"to_bytes",
);
let zig_result = call_bitcode_fn(env, &[string], &bitcode::STR_TO_BYTES);
complex_bitcast(
env.builder,
zig_result,
collection(env.context, env.ptr_bytes).into(),
"to_bytes",
)
}
/// Str.fromUtf8 : List U8 -> { a : Bool, b : Str, c : Nat, d : I8 } /// Str.fromUtf8 : List U8 -> { a : Bool, b : Str, c : Nat, d : I8 }
pub fn str_from_utf8<'a, 'ctx, 'env>( pub fn str_from_utf8<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>, env: &Env<'a, 'ctx, 'env>,
parent: FunctionValue<'ctx>, _parent: FunctionValue<'ctx>,
original_wrapper: StructValue<'ctx>, original_wrapper: StructValue<'ctx>,
) -> BasicValueEnum<'ctx> { ) -> BasicValueEnum<'ctx> {
let builder = env.builder; let builder = env.builder;
let ctx = env.context; let ctx = env.context;
let list_len = list_len(builder, original_wrapper); let result_type = env.module.get_struct_type("str.FromUtf8Result").unwrap();
let ptr_type = get_ptr_type(&ctx.i8_type().into(), AddressSpace::Generic);
let list_ptr = load_list_ptr(builder, original_wrapper, ptr_type);
let result_type = env
.module
.get_struct_type("str.ValidateUtf8BytesResult")
.unwrap();
let result_ptr = builder.build_alloca(result_type, "alloca_utf8_validate_bytes_result"); let result_ptr = builder.build_alloca(result_type, "alloca_utf8_validate_bytes_result");
call_void_bitcode_fn( call_void_bitcode_fn(
env, env,
&[result_ptr.into(), list_ptr.into(), list_len.into()], &[
&bitcode::STR_VALIDATE_UTF_BYTES, complex_bitcast(
env.builder,
original_wrapper.into(),
env.context.i128_type().into(),
"to_i128",
),
result_ptr.into(),
],
&bitcode::STR_FROM_UTF8,
); );
let utf8_validate_bytes_result = builder
.build_load(result_ptr, "load_utf8_validate_bytes_result")
.into_struct_value();
let is_ok = builder
.build_extract_value(utf8_validate_bytes_result, 0, "extract_extract_is_ok")
.unwrap()
.into_int_value();
let byte_index = builder
.build_extract_value(utf8_validate_bytes_result, 1, "extract_byte_index")
.unwrap()
.into_int_value();
let problem_code = builder
.build_extract_value(utf8_validate_bytes_result, 2, "extract_problem_code")
.unwrap()
.into_int_value();
let record_type = env.context.struct_type( let record_type = env.context.struct_type(
&[ &[
@ -326,71 +331,16 @@ pub fn str_from_utf8<'a, 'ctx, 'env>(
false, false,
); );
let comparison = builder.build_int_compare( let result_ptr_cast = env
IntPredicate::EQ, .builder
is_ok, .build_bitcast(
ctx.bool_type().const_int(1, false), result_ptr,
"compare_is_ok", record_type.ptr_type(AddressSpace::Generic),
); "to_unnamed",
)
.into_pointer_value();
build_basic_phi2( builder.build_load(result_ptr_cast, "load_utf8_validate_bytes_result")
env,
parent,
comparison,
|| {
// We have a valid utf8 byte sequence
// TODO: Should we do something different here if we're doing this in place?
let zig_str =
call_bitcode_fn(env, &[list_ptr.into(), list_len.into()], &bitcode::STR_INIT)
.into_struct_value();
build_struct(
builder,
record_type,
vec![
(
env.ptr_int().const_int(0, false).into(),
"insert_zeroed_byte_index",
),
(zig_str_to_struct(env, zig_str).into(), "insert_str"),
(ctx.bool_type().const_int(1, false).into(), "insert_is_ok"),
(
ctx.i8_type().const_int(0, false).into(),
"insert_zeroed_problem",
),
],
)
.into()
},
|| {
// We do not have a valid utf8 byte sequence
build_struct(
builder,
record_type,
vec![
(byte_index.into(), "insert_byte_index"),
(empty_polymorphic_list(env), "insert_zeroed_str"),
(ctx.bool_type().const_int(0, false).into(), "insert_is_ok"),
(problem_code.into(), "insert_problem"),
],
)
.into()
},
BasicTypeEnum::StructType(record_type),
)
}
fn build_struct<'env, 'ctx>(
builder: &'env Builder<'ctx>,
struct_type: StructType<'ctx>,
values: Vec<(BasicValueEnum<'ctx>, &str)>,
) -> StructValue<'ctx> {
let mut val = struct_type.get_undef().into();
for (index, (value, name)) in values.iter().enumerate() {
val = builder
.build_insert_value(val, *value, index as u32, name)
.unwrap();
}
val.into_struct_value()
} }
/// Str.fromInt : Int -> Str /// Str.fromInt : Int -> Str

View file

@ -750,6 +750,12 @@ mod gen_num {
assert_evals_to!("Num.bitwiseXor 200 0", 200, i64); assert_evals_to!("Num.bitwiseXor 200 0", 200, i64);
} }
#[test]
fn bitwise_or() {
// identical operands: 1 | 1 == 1
assert_evals_to!("Num.bitwiseOr 1 1", 1, i64);
// disjoint bits are combined: 0b01 | 0b10 == 0b11
assert_evals_to!("Num.bitwiseOr 1 2", 3, i64);
}
#[test] #[test]
fn lt_i64() { fn lt_i64() {
assert_evals_to!("1 < 2", true, bool); assert_evals_to!("1 < 2", true, bool);
@ -1343,4 +1349,29 @@ mod gen_num {
f64 f64
); );
} }
#[test]
fn shift_left_by() {
// The shift amount is the first argument; the value being shifted is the second.
// shifting by 0 leaves the value unchanged
assert_evals_to!("Num.shiftLeftBy 0 0b0000_0001", 0b0000_0001, i64);
assert_evals_to!("Num.shiftLeftBy 1 0b0000_0001", 0b0000_0010, i64);
assert_evals_to!("Num.shiftLeftBy 2 0b0000_0011", 0b0000_1100, i64);
}
#[test]
#[ignore]
fn shift_right_by() {
    // Sign Extended Right Shift
    //
    // The shift amount is the first argument; the value being shifted is the second.

    // sign bit clear: zeros are shifted in, 0b0100_0000 >> 2 == 0b0001_0000
    // (the shift amount used to be 0, which cannot produce this expected value)
    assert_evals_to!("Num.shiftRightBy 2 0b0100_0000i8", 0b0001_0000, i8);

    // sign bit set: ones are shifted in from the left
    assert_evals_to!("Num.shiftRightBy 1 0b1110_0000u8", 0b1111_0000u8 as i8, i8);
    assert_evals_to!("Num.shiftRightBy 2 0b1100_0000u8", 0b1111_0000u8 as i8, i8);
}
#[test]
#[ignore]
fn shift_right_zf_by() {
    // Logical Right Shift
    //
    // The shift amount is the first argument; the value being shifted is the second.
    // This test must exercise `Num.shiftRightZfBy` (it previously called
    // `Num.shiftRightBy`), and each shift amount has to match its expected value.
    // The operands are `u8`, so the result is read back as `u8` as well.

    // 0b1100_0000 >> 2 == 0b0011_0000: zeros are shifted in even though the top bit is set
    assert_evals_to!("Num.shiftRightZfBy 2 0b1100_0000u8", 0b0011_0000u8, u8);
    // 0b0000_0010 >> 1 == 0b0000_0001
    assert_evals_to!("Num.shiftRightZfBy 1 0b0000_0010u8", 0b0000_0001u8, u8);
    // 0b0000_1100 >> 2 == 0b0000_0011
    assert_evals_to!("Num.shiftRightZfBy 2 0b0000_1100u8", 0b0000_0011u8, u8);
}
} }

View file

@ -816,4 +816,17 @@ mod gen_str {
fn str_from_float() { fn str_from_float() {
assert_evals_to!(r#"Str.fromFloat 3.14"#, RocStr::from("3.140000"), RocStr); assert_evals_to!(r#"Str.fromFloat 3.14"#, RocStr::from("3.140000"), RocStr);
} }
#[test]
fn str_to_bytes() {
// 5-byte input: the expected list holds the byte values of "hello"
assert_evals_to!(r#"Str.toBytes "hello""#, &[104, 101, 108, 108, 111], &[u8]);
// 21-byte input -- presumably long enough to use the non-small string
// representation (TODO confirm against the small-string length limit)
assert_evals_to!(
r#"Str.toBytes "this is a long string""#,
&[
116, 104, 105, 115, 32, 105, 115, 32, 97, 32, 108, 111, 110, 103, 32, 115, 116,
114, 105, 110, 103
],
&[u8]
);
}
} }

View file

@ -358,6 +358,8 @@ struct ModuleCache<'a> {
external_specializations_requested: MutMap<ModuleId, ExternalSpecializations>, external_specializations_requested: MutMap<ModuleId, ExternalSpecializations>,
/// Various information /// Various information
imports: MutMap<ModuleId, MutSet<ModuleId>>,
top_level_thunks: MutMap<ModuleId, MutSet<Symbol>>,
documentation: MutMap<ModuleId, ModuleDocumentation>, documentation: MutMap<ModuleId, ModuleDocumentation>,
can_problems: MutMap<ModuleId, Vec<roc_problem::can::Problem>>, can_problems: MutMap<ModuleId, Vec<roc_problem::can::Problem>>,
type_problems: MutMap<ModuleId, Vec<solve::TypeError>>, type_problems: MutMap<ModuleId, Vec<solve::TypeError>>,
@ -544,11 +546,24 @@ fn start_phase<'a>(module_id: ModuleId, phase: Phase, state: &mut State<'a>) ->
ident_ids, ident_ids,
} = typechecked; } = typechecked;
let mut imported_module_thunks = MutSet::default();
if let Some(imports) = state.module_cache.imports.get(&module_id) {
for imported in imports.iter() {
imported_module_thunks.extend(
state.module_cache.top_level_thunks[imported]
.iter()
.copied(),
);
}
}
BuildTask::BuildPendingSpecializations { BuildTask::BuildPendingSpecializations {
layout_cache, layout_cache,
module_id, module_id,
module_timing, module_timing,
solved_subs, solved_subs,
imported_module_thunks,
decls, decls,
ident_ids, ident_ids,
exposed_to_host: state.exposed_to_host.clone(), exposed_to_host: state.exposed_to_host.clone(),
@ -957,6 +972,7 @@ enum BuildTask<'a> {
module_timing: ModuleTiming, module_timing: ModuleTiming,
layout_cache: LayoutCache<'a>, layout_cache: LayoutCache<'a>,
solved_subs: Solved<Subs>, solved_subs: Solved<Subs>,
imported_module_thunks: MutSet<Symbol>,
module_id: ModuleId, module_id: ModuleId,
ident_ids: IdentIds, ident_ids: IdentIds,
decls: Vec<Declaration>, decls: Vec<Declaration>,
@ -1662,6 +1678,18 @@ fn update<'a>(
.exposed_symbols_by_module .exposed_symbols_by_module
.insert(home, exposed_symbols); .insert(home, exposed_symbols);
state
.module_cache
.imports
.entry(header.module_id)
.or_default()
.extend(
header
.package_qualified_imported_modules
.iter()
.map(|x| *x.as_inner()),
);
work.extend(state.dependencies.add_module( work.extend(state.dependencies.add_module(
header.module_id, header.module_id,
&header.package_qualified_imported_modules, &header.package_qualified_imported_modules,
@ -1925,6 +1953,13 @@ fn update<'a>(
} }
} }
state
.module_cache
.top_level_thunks
.entry(module_id)
.or_default()
.extend(procs.module_thunks.iter().copied());
let found_specializations_module = FoundSpecializationsModule { let found_specializations_module = FoundSpecializationsModule {
layout_cache, layout_cache,
module_id, module_id,
@ -3747,6 +3782,7 @@ fn make_specializations<'a>(
fn build_pending_specializations<'a>( fn build_pending_specializations<'a>(
arena: &'a Bump, arena: &'a Bump,
solved_subs: Solved<Subs>, solved_subs: Solved<Subs>,
imported_module_thunks: MutSet<Symbol>,
home: ModuleId, home: ModuleId,
mut ident_ids: IdentIds, mut ident_ids: IdentIds,
decls: Vec<Declaration>, decls: Vec<Declaration>,
@ -3759,6 +3795,9 @@ fn build_pending_specializations<'a>(
let find_specializations_start = SystemTime::now(); let find_specializations_start = SystemTime::now();
let mut procs = Procs::default(); let mut procs = Procs::default();
debug_assert!(procs.imported_module_thunks.is_empty());
procs.imported_module_thunks = imported_module_thunks;
let mut mono_problems = std::vec::Vec::new(); let mut mono_problems = std::vec::Vec::new();
let mut subs = solved_subs.into_inner(); let mut subs = solved_subs.into_inner();
let mut mono_env = roc_mono::ir::Env { let mut mono_env = roc_mono::ir::Env {
@ -4040,10 +4079,12 @@ where
module_timing, module_timing,
layout_cache, layout_cache,
solved_subs, solved_subs,
imported_module_thunks,
exposed_to_host, exposed_to_host,
} => Ok(build_pending_specializations( } => Ok(build_pending_specializations(
arena, arena,
solved_subs, solved_subs,
imported_module_thunks,
module_id, module_id,
ident_ids, ident_ids,
decls, decls,

View file

@ -12,6 +12,7 @@ pub enum LowLevel {
StrCountGraphemes, StrCountGraphemes,
StrFromInt, StrFromInt,
StrFromUtf8, StrFromUtf8,
StrToBytes,
StrFromFloat, StrFromFloat,
ListLen, ListLen,
ListGetUnsafe, ListGetUnsafe,
@ -79,6 +80,11 @@ pub enum LowLevel {
NumAsin, NumAsin,
NumBitwiseAnd, NumBitwiseAnd,
NumBitwiseXor, NumBitwiseXor,
NumBitwiseOr,
NumShiftLeftBy,
NumShiftRightBy,
NumShiftRightZfBy,
NumIntCast,
Eq, Eq,
NotEq, NotEq,
And, And,

View file

@ -841,15 +841,21 @@ define_builtins! {
80 NUM_BINARY32: "Binary32" imported 80 NUM_BINARY32: "Binary32" imported
81 NUM_BITWISE_AND: "bitwiseAnd" 81 NUM_BITWISE_AND: "bitwiseAnd"
82 NUM_BITWISE_XOR: "bitwiseXor" 82 NUM_BITWISE_XOR: "bitwiseXor"
83 NUM_SUB_WRAP: "subWrap" 83 NUM_BITWISE_OR: "bitwiseOr"
84 NUM_SUB_CHECKED: "subChecked" 84 NUM_SHIFT_LEFT: "shiftLeftBy"
85 NUM_MUL_WRAP: "mulWrap" 85 NUM_SHIFT_RIGHT: "shiftRightBy"
86 NUM_MUL_CHECKED: "mulChecked" 86 NUM_SHIFT_RIGHT_ZERO_FILL: "shiftRightZfBy"
87 NUM_INT: "Int" imported 87 NUM_SUB_WRAP: "subWrap"
88 NUM_FLOAT: "Float" imported 88 NUM_SUB_CHECKED: "subChecked"
89 NUM_AT_NATURAL: "@Natural" 89 NUM_MUL_WRAP: "mulWrap"
90 NUM_NATURAL: "Natural" imported 90 NUM_MUL_CHECKED: "mulChecked"
91 NUM_NAT: "Nat" imported 91 NUM_INT: "Int" imported
92 NUM_FLOAT: "Float" imported
93 NUM_AT_NATURAL: "@Natural"
94 NUM_NATURAL: "Natural" imported
95 NUM_NAT: "Nat" imported
96 NUM_INT_CAST: "intCast"
} }
2 BOOL: "Bool" => { 2 BOOL: "Bool" => {
0 BOOL_BOOL: "Bool" imported // the Bool.Bool type alias 0 BOOL_BOOL: "Bool" imported // the Bool.Bool type alias
@ -876,6 +882,7 @@ define_builtins! {
12 STR_FROM_UTF8: "fromUtf8" 12 STR_FROM_UTF8: "fromUtf8"
13 STR_UT8_PROBLEM: "Utf8Problem" // the Utf8Problem type alias 13 STR_UT8_PROBLEM: "Utf8Problem" // the Utf8Problem type alias
14 STR_UT8_BYTE_PROBLEM: "Utf8ByteProblem" // the Utf8ByteProblem type alias 14 STR_UT8_BYTE_PROBLEM: "Utf8ByteProblem" // the Utf8ByteProblem type alias
15 STR_TO_BYTES: "toBytes"
} }
4 LIST: "List" => { 4 LIST: "List" => {
0 LIST_LIST: "List" imported // the List.List type alias 0 LIST_LIST: "List" imported // the List.List type alias

View file

@ -373,6 +373,14 @@ impl<'a> BorrowInfState<'a> {
self.own_var(z); self.own_var(z);
// if the function expects an owned argument (ps), the argument must be owned (args) // if the function expects an owned argument (ps), the argument must be owned (args)
debug_assert_eq!(
arguments.len(),
ps.len(),
"{:?} has {} parameters, but was applied to {} arguments",
name,
ps.len(),
arguments.len()
);
self.own_args_using_params(arguments, ps); self.own_args_using_params(arguments, ps);
} }
None => { None => {
@ -658,14 +666,17 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] {
And | Or | NumAdd | NumAddWrap | NumAddChecked | NumSub | NumSubWrap | NumSubChecked And | Or | NumAdd | NumAddWrap | NumAddChecked | NumSub | NumSubWrap | NumSubChecked
| NumMul | NumMulWrap | NumMulChecked | NumGt | NumGte | NumLt | NumLte | NumCompare | NumMul | NumMulWrap | NumMulChecked | NumGt | NumGte | NumLt | NumLte | NumCompare
| NumDivUnchecked | NumRemUnchecked | NumPow | NumPowInt | NumBitwiseAnd | NumDivUnchecked | NumRemUnchecked | NumPow | NumPowInt | NumBitwiseAnd
| NumBitwiseXor => arena.alloc_slice_copy(&[irrelevant, irrelevant]), | NumBitwiseXor | NumBitwiseOr | NumShiftLeftBy | NumShiftRightBy | NumShiftRightZfBy => {
arena.alloc_slice_copy(&[irrelevant, irrelevant])
}
NumAbs | NumNeg | NumSin | NumCos | NumSqrtUnchecked | NumRound | NumCeiling | NumFloor NumAbs | NumNeg | NumSin | NumCos | NumSqrtUnchecked | NumRound | NumCeiling | NumFloor
| NumToFloat | Not | NumIsFinite | NumAtan | NumAcos | NumAsin => { | NumToFloat | Not | NumIsFinite | NumAtan | NumAcos | NumAsin | NumIntCast => {
arena.alloc_slice_copy(&[irrelevant]) arena.alloc_slice_copy(&[irrelevant])
} }
StrStartsWith | StrEndsWith => arena.alloc_slice_copy(&[owned, borrowed]), StrStartsWith | StrEndsWith => arena.alloc_slice_copy(&[owned, borrowed]),
StrFromUtf8 => arena.alloc_slice_copy(&[owned]), StrFromUtf8 => arena.alloc_slice_copy(&[owned]),
StrToBytes => arena.alloc_slice_copy(&[owned]),
StrFromInt | StrFromFloat => arena.alloc_slice_copy(&[irrelevant]), StrFromInt | StrFromFloat => arena.alloc_slice_copy(&[irrelevant]),
Hash => arena.alloc_slice_copy(&[borrowed, irrelevant]), Hash => arena.alloc_slice_copy(&[borrowed, irrelevant]),
DictSize => arena.alloc_slice_copy(&[borrowed]), DictSize => arena.alloc_slice_copy(&[borrowed]),

View file

@ -273,6 +273,7 @@ impl ExternalSpecializations {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Procs<'a> { pub struct Procs<'a> {
pub partial_procs: MutMap<Symbol, PartialProc<'a>>, pub partial_procs: MutMap<Symbol, PartialProc<'a>>,
pub imported_module_thunks: MutSet<Symbol>,
pub module_thunks: MutSet<Symbol>, pub module_thunks: MutSet<Symbol>,
pub pending_specializations: Option<MutMap<Symbol, MutMap<Layout<'a>, PendingSpecialization>>>, pub pending_specializations: Option<MutMap<Symbol, MutMap<Layout<'a>, PendingSpecialization>>>,
pub specialized: MutMap<(Symbol, Layout<'a>), InProgressProc<'a>>, pub specialized: MutMap<(Symbol, Layout<'a>), InProgressProc<'a>>,
@ -285,6 +286,7 @@ impl<'a> Default for Procs<'a> {
fn default() -> Self { fn default() -> Self {
Self { Self {
partial_procs: MutMap::default(), partial_procs: MutMap::default(),
imported_module_thunks: MutSet::default(),
module_thunks: MutSet::default(), module_thunks: MutSet::default(),
pending_specializations: Some(MutMap::default()), pending_specializations: Some(MutMap::default()),
specialized: MutMap::default(), specialized: MutMap::default(),
@ -302,39 +304,6 @@ pub enum InProgressProc<'a> {
} }
impl<'a> Procs<'a> { impl<'a> Procs<'a> {
/// Absorb the contents of another Procs into this one.
pub fn absorb(&mut self, mut other: Procs<'a>) {
debug_assert!(self.pending_specializations.is_some());
debug_assert!(other.pending_specializations.is_some());
match self.pending_specializations {
Some(ref mut pending_specializations) => {
for (k, v) in other.pending_specializations.unwrap().drain() {
pending_specializations.insert(k, v);
}
}
None => {
unreachable!();
}
}
for (k, v) in other.partial_procs.drain() {
self.partial_procs.insert(k, v);
}
for (k, v) in other.specialized.drain() {
self.specialized.insert(k, v);
}
for (k, v) in other.runtime_errors.drain() {
self.runtime_errors.insert(k, v);
}
for symbol in other.module_thunks.drain() {
self.module_thunks.insert(symbol);
}
}
pub fn get_specialized_procs_without_rc( pub fn get_specialized_procs_without_rc(
self, self,
arena: &'a Bump, arena: &'a Bump,
@ -5751,8 +5720,18 @@ fn call_by_pointer<'a>(
let is_specialized = procs.specialized.keys().any(|(s, _)| *s == symbol); let is_specialized = procs.specialized.keys().any(|(s, _)| *s == symbol);
if env.is_imported_symbol(symbol) || procs.partial_procs.contains_key(&symbol) || is_specialized if env.is_imported_symbol(symbol) || procs.partial_procs.contains_key(&symbol) || is_specialized
{ {
// anything that is not a thunk can be called by-value in the wrapper
// (the above condition guarantees we're dealing with a top-level symbol)
//
// But thunks cannot be called by-value, since they are not really functions to all parts
// of the system (notably RC insertion). So we still call those by-pointer.
// Luckily such values were top-level originally (in the user code), and can therefore
// not be closures
let is_thunk =
procs.module_thunks.contains(&symbol) || procs.imported_module_thunks.contains(&symbol);
match layout { match layout {
Layout::FunctionPointer(arg_layouts, ret_layout) => { Layout::FunctionPointer(arg_layouts, ret_layout) if !is_thunk => {
if arg_layouts.iter().any(|l| l.contains_refcounted()) { if arg_layouts.iter().any(|l| l.contains_refcounted()) {
let name = env.unique_symbol(); let name = env.unique_symbol();
let mut args = Vec::with_capacity_in(arg_layouts.len(), env.arena); let mut args = Vec::with_capacity_in(arg_layouts.len(), env.arena);
@ -5766,6 +5745,7 @@ fn call_by_pointer<'a>(
let args = args.into_bump_slice(); let args = args.into_bump_slice();
let call_symbol = env.unique_symbol(); let call_symbol = env.unique_symbol();
debug_assert_eq!(arg_layouts.len(), arg_symbols.len());
let call_type = CallType::ByName { let call_type = CallType::ByName {
name: symbol, name: symbol,
full_layout: layout.clone(), full_layout: layout.clone(),
@ -5804,6 +5784,63 @@ fn call_by_pointer<'a>(
Expr::FunctionPointer(symbol, layout) Expr::FunctionPointer(symbol, layout)
} }
} }
Layout::FunctionPointer(arg_layouts, ret_layout) => {
if arg_layouts.iter().any(|l| l.contains_refcounted()) {
let name = env.unique_symbol();
let mut args = Vec::with_capacity_in(arg_layouts.len(), env.arena);
let mut arg_symbols = Vec::with_capacity_in(arg_layouts.len(), env.arena);
for layout in arg_layouts {
let symbol = env.unique_symbol();
args.push((layout.clone(), symbol));
arg_symbols.push(symbol);
}
let args = args.into_bump_slice();
let call_symbol = env.unique_symbol();
let fpointer_symbol = env.unique_symbol();
debug_assert_eq!(arg_layouts.len(), arg_symbols.len());
let call_type = CallType::ByPointer {
name: fpointer_symbol,
full_layout: layout.clone(),
ret_layout: ret_layout.clone(),
arg_layouts,
};
let call = Call {
call_type,
arguments: arg_symbols.into_bump_slice(),
};
let expr = Expr::Call(call);
let mut body = Stmt::Ret(call_symbol);
body = Stmt::Let(call_symbol, expr, ret_layout.clone(), env.arena.alloc(body));
let expr = Expr::FunctionPointer(symbol, layout.clone());
body = Stmt::Let(fpointer_symbol, expr, layout.clone(), env.arena.alloc(body));
let closure_data_layout = None;
let proc = Proc {
name,
args,
body,
closure_data_layout,
ret_layout: ret_layout.clone(),
is_self_recursive: SelfRecursive::NotSelfRecursive,
must_own_arguments: true,
host_exposed_layouts: HostExposedLayouts::NotHostExposed,
};
procs
.specialized
.insert((name, layout.clone()), InProgressProc::Done(proc));
Expr::FunctionPointer(name, layout)
} else {
// if none of the arguments is refcounted, then owning the arguments has no
// meaning
Expr::FunctionPointer(symbol, layout)
}
}
_ => { _ => {
// e.g. Num.maxInt or other constants // e.g. Num.maxInt or other constants
Expr::FunctionPointer(symbol, layout) Expr::FunctionPointer(symbol, layout)

View file

@ -603,10 +603,7 @@ fn to_if_report<'a>(
start_row, start_row,
start_col, start_col,
alloc.concat(vec![ alloc.concat(vec![
alloc.reflow(r"I just saw a pattern, so I was expecting to see a "), alloc.reflow(r"I was expecting to see a expression next")
alloc.parser_suggestion("->"),
alloc.reflow(" next."),
alloc.reflow(r"I was expecting to see a expression next"),
]), ]),
), ),
} }

View file

@ -4315,4 +4315,26 @@ mod solve_expr {
"Str", "Str",
); );
} }
// `x` is an unannotated number literal that is also passed to `f : U8 -> U32`.
// The expected type of `main` (which is just `x`) is still `Num *`, i.e. the use
// of `x` at `U8` inside `y = f x` must not specialize the definition of `x` itself.
#[test]
fn int_type_let_polymorphism() {
infer_eq_without_problem(
indoc!(
r#"
app "test" provides [ main ] to "./platform"
x = 4
f : U8 -> U32
f = \z -> Num.intCast z
y = f x
main =
x
"#
),
"Num *",
);
}
} }

View file

@ -0,0 +1,140 @@
interface Base64 exposes [ fromBytes ] imports [ Bytes.Decode ]
# Base64 text encoding of a list of bytes, written as a `Bytes.Decode` decoder
# that consumes the input bytes and accumulates the output string.
# Local shorthand for the decoder type from Bytes.Decode
Decoder a : Bytes.Decode.Decoder a
# Produce the base64 string for `bytes` by running `decodeBase64` over all of them.
# Fails with a DecodeError if the decoder runs out of bytes.
fromBytes : List U8 -> Result Str Bytes.Decode.DecodeError
fromBytes = \bytes ->
Bytes.Decode.decode bytes (decodeBase64 (List.len bytes))
# Decoder that processes `width` bytes, starting from the empty string
decodeBase64 : Nat -> Bytes.Decode.Decoder Str
decodeBase64 = \width -> Bytes.Decode.loop loopHelp { remaining: width, string: "" }
# One loop step. `remaining` is the number of input bytes still to process and
# `string` the output produced so far.
# - 3 or more bytes left: read three bytes, pack them into 24 bits and append 4 characters
# - 0 bytes left: done
# - 2 bytes left: read two bytes, append 3 characters and one '='
# - 1 byte left: read one byte, append 2 characters and two '='
loopHelp : { remaining : Nat, string : Str } -> Decoder (Bytes.Decode.Step { remaining : Nat, string : Str } Str)
loopHelp = \{ remaining, string } ->
if remaining >= 3 then
Bytes.Decode.map3
Bytes.Decode.u8
Bytes.Decode.u8
Bytes.Decode.u8
\x, y, z ->
a : U32
a = Num.intCast x
b : U32
b = Num.intCast y
c : U32
c = Num.intCast z
# NOTE the shift amount is the first argument: this is (a << 16) | (b << 8) | c
combined = Num.bitwiseOr (Num.bitwiseOr (Num.shiftLeftBy 16 a) (Num.shiftLeftBy 8 b)) c
Loop
{
remaining: remaining - 3,
string: Str.concat string (bitsToChars combined 0)
}
else if remaining == 0 then
Bytes.Decode.succeed (Done string)
else if remaining == 2 then
Bytes.Decode.map2
Bytes.Decode.u8
Bytes.Decode.u8
\x, y ->
a : U32
a = Num.intCast x
b : U32
b = Num.intCast y
# the lowest 8 bits stay zero: (a << 16) | (b << 8)
combined = Num.bitwiseOr (Num.shiftLeftBy 16 a) (Num.shiftLeftBy 8 b)
Done (Str.concat string (bitsToChars combined 1))
else
# remaining = 1
Bytes.Decode.map
Bytes.Decode.u8
\x ->
a : U32
a = Num.intCast x
Done (Str.concat string (bitsToChars (Num.shiftLeftBy 16 a) 2))
# Convert the packed bits into a string of 4 characters, `missing` of which are
# '=' padding. If the produced bytes are not valid UTF-8 the result is the empty string.
bitsToChars : U32, Int * -> Str
bitsToChars = \bits, missing ->
when Str.fromUtf8 (bitsToCharsHelp bits missing) is
Ok str -> str
Err _ -> ""
# Mask that can be used to get the lowest 6 bits of a binary number
lowest6BitsMask : Int *
lowest6BitsMask = 63
# Split `bits` into four 6-bit groups (p is the highest group, s the lowest), map
# each group to its base64 character, and replace the last `missing` characters
# by '='. A `missing` value other than 0, 1 or 2 yields the empty list.
bitsToCharsHelp : U32, Int * -> List U8
bitsToCharsHelp = \bits, missing ->
# The input is 24 bits, which we have to partition into 4 6-bit segments. We achieve this by
# shifting to the right by (a multiple of) 6 to remove unwanted bits on the right, then `Num.bitwiseAnd`
# with `0b111111` (which is 2^6 - 1 or 63) (so, 6 1s) to remove unwanted bits on the left.
# any 6-bit number is a valid base64 digit, so this is actually safe
p =
Num.shiftRightZfBy 18 bits
|> Num.intCast
|> unsafeToChar
q =
Num.bitwiseAnd (Num.shiftRightZfBy 12 bits) lowest6BitsMask
|> Num.intCast
|> unsafeToChar
r =
Num.bitwiseAnd (Num.shiftRightZfBy 6 bits) lowest6BitsMask
|> Num.intCast
|> unsafeToChar
s =
Num.bitwiseAnd bits lowest6BitsMask
|> Num.intCast
|> unsafeToChar
# byte value of the '=' padding character
equals : U8
equals = 61
when missing is
0 ->
[ p, q, r, s ]
1 ->
[ p, q, r, equals ]
2 ->
[ p, q, equals , equals ]
_ ->
# unreachable
[]
# Base64 index to character/digit
# Indices 0-63 map to a character byte; any other input maps to 0.
unsafeToChar : U8 -> U8
unsafeToChar = \n ->
if n <= 25 then
# uppercase characters
65 + n
else if n <= 51 then
# lowercase characters
97 + (n - 26)
else if n <= 61 then
# digit characters
48 + (n - 52)
else
# special cases
when n is
62 ->
# '+'
43
63 ->
# '/'
47
_ ->
# anything else is invalid '\u{0000}'
0

View file

@ -0,0 +1,106 @@
interface Bytes.Decode exposes [ Decoder, decode, map, map2, u8, loop, Step, succeed, DecodeError, after, map3 ] imports []
State : { bytes: List U8, cursor : Nat }
DecodeError : [ OutOfBytes ]
Decoder a : [ @Decoder (State -> [Good State a, Bad DecodeError]) ]
decode : List U8, Decoder a -> Result a DecodeError
decode = \bytes, @Decoder decoder ->
when decoder { bytes, cursor: 0 } is
Good _ value ->
Ok value
Bad e ->
Err e
succeed : a -> Decoder a
succeed = \value -> @Decoder \state -> Good state value
# Apply `transform` to the value a decoder produces.
# A `Bad` result is passed through unchanged.
map : Decoder a, (a -> b) -> Decoder b
map = \@Decoder inner, transform ->
    @Decoder \input ->
        when inner input is
            Bad problem ->
                Bad problem

            Good rest decoded ->
                Good rest (transform decoded)
# Run two decoders one after the other, threading the state from the first into the second,
# and combine their values with `transform`. The first `Bad` result stops the chain.
map2 : Decoder a, Decoder b, (a, b -> c) -> Decoder c
map2 = \@Decoder first, @Decoder second, transform ->
    @Decoder \input ->
        when first input is
            Bad problem ->
                Bad problem

            Good afterFirst x ->
                when second afterFirst is
                    Bad problem ->
                        Bad problem

                    Good afterSecond y ->
                        Good afterSecond (transform x y)
# Run three decoders in sequence, threading the state through each of them,
# and combine their values with `transform`. The first `Bad` result stops the chain.
map3 : Decoder a, Decoder b, Decoder c, (a, b, c -> d) -> Decoder d
map3 = \@Decoder first, @Decoder second, @Decoder third, transform ->
    @Decoder \input ->
        when first input is
            Bad problem ->
                Bad problem

            Good afterFirst x ->
                when second afterFirst is
                    Bad problem ->
                        Bad problem

                    Good afterSecond y ->
                        when third afterSecond is
                            Bad problem ->
                                Bad problem

                            Good afterThird z ->
                                Good afterThird (transform x y z)
# Run a decoder, then use its value to choose the decoder that runs next.
# The follow-up decoder continues from the state the first one left behind;
# a `Bad` result from the first decoder is passed through unchanged.
after : Decoder a, (a -> Decoder b) -> Decoder b
after = \@Decoder first, transform ->
    @Decoder \input ->
        when first input is
            Bad problem ->
                Bad problem

            Good rest value ->
                (@Decoder followUp) = transform value

                followUp rest
# Decoder that reads the byte at the cursor and moves the cursor forward by one.
# Fails with `OutOfBytes` when `List.get` finds no byte at the cursor.
u8 : Decoder U8
u8 = @Decoder \input ->
    when List.get input.bytes input.cursor is
        Err _ ->
            Bad OutOfBytes

        Ok byte ->
            advanced = { input & cursor: input.cursor + 1 }

            Good advanced byte
# Outcome of one `loop` iteration: `Loop` carries the accumulator for the next iteration,
# `Done` carries the final value and ends the loop.
Step state b : [ Loop state, Done b ]
# Build a decoder that repeatedly runs `stepper`, starting from the accumulator `initial`,
# until a step answers `Done` or a step fails.
loop : (state -> Decoder (Step state a)), state -> Decoder a
loop = \stepper, initial ->
    @Decoder \input -> loopHelp stepper initial input
# Worker for `loop`: run one step from `accum` against `state`, then either
# finish (`Done`), recurse with the new accumulator and state (`Loop`), or propagate the failure.
loopHelp = \stepper, accum, state ->
    (@Decoder step) = stepper accum

    when step state is
        Bad problem ->
            Bad problem

        Good next (Loop nextAccum) ->
            loopHelp stepper nextAccum next

        Good next (Done result) ->
            Good next result

View file

@ -1,4 +1,4 @@
app "astar-tests" app "test-astar"
packages { base: "platform" } packages { base: "platform" }
imports [base.Task, AStar] imports [base.Task, AStar]
provides [ main ] to base provides [ main ] to base

View file

@ -0,0 +1,16 @@
app "test-base64"
packages { base: "platform" }
imports [base.Task, Base64 ]
provides [ main ] to base
# A task that produces an `a` and cannot fail.
IO a : Task.Task a []

# Convert "Hello World" to bytes, pass them to `Base64.fromBytes`, and print the resulting string.
# Prints "sadness" instead if `Base64.fromBytes` returns an error.
main : IO {}
main =
    input = Str.toBytes "Hello World"

    when Base64.fromBytes input is
        Err _ ->
            Task.putLine "sadness"

        Ok encoded ->
            Task.putLine encoded