Merge branch 'roc-lang:main' into list-update

This commit is contained in:
KilianVounckx 2023-06-01 11:16:33 +02:00 committed by GitHub
commit 8b85f966fd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 1042 additions and 424 deletions

194
crates/compiler/DESIGN.md Normal file
View file

@ -0,0 +1,194 @@
# Compiler Design
The current Roc compiler is designed as a pipelining compiler parallelizable
across Roc modules.
Roc's compilation pipeline consists of a few major components, which form the
table of contents for this document.
<!-- START doctoc generated TOC please keep comment here to allow auto update -->
<!-- DON'T EDIT THIS SECTION, INSTEAD RE-RUN doctoc TO UPDATE -->
- [Parsing](#parsing)
- [Canonicalization](#canonicalization)
- [Symbol Resolution](#symbol-resolution)
- [Type-alias normalization](#type-alias-normalization)
- [Closure naming](#closure-naming)
- [Constraint Generation](#constraint-generation)
- [(Mutually-)recursive definitions](#mutually-recursive-definitions)
- [Type Solving](#type-solving)
- [Unification](#unification)
- [Type Inference](#type-inference)
- [Recursive Types](#recursive-types)
- [Lambda Sets](#lambda-sets)
- [Ability Collection](#ability-collection)
- [Ability Specialization](#ability-specialization)
- [Ability Derivation](#ability-derivation)
- [Exhaustiveness Checking](#exhaustiveness-checking)
- [Debugging](#debugging)
- [IR Generation](#ir-generation)
- [Memory Layouts](#memory-layouts)
- [Compiling Calls](#compiling-calls)
- [Decision Trees](#decision-trees)
- [Tail-call Optimization](#tail-call-optimization)
- [Reference-count insertion](#reference-count-insertion)
- [Reusing Memory Allocations](#reusing-memory-allocations)
- [Debugging](#debugging-1)
- [LLVM Code Generator](#llvm-code-generator)
- [Morphic Analysis](#morphic-analysis)
- [C ABI](#c-abi)
- [Test Harness](#test-harness)
- [Debugging](#debugging-2)
- [WASM Code Generator](#wasm-code-generator)
- [WASM Interpreter](#wasm-interpreter)
- [Debugging](#debugging-3)
- [Dev Code Generator](#dev-code-generator)
- [Debugging](#debugging-4)
- [Builtins](#builtins)
- [Compiler Driver](#compiler-driver)
- [Caching types](#caching-types)
- [Repl](#repl)
- [`test` and `dbg`](#test-and-dbg)
- [Formatter](#formatter)
- [Glue](#glue)
- [Active areas of research / help wanted](#active-areas-of-research--help-wanted)
<!-- END doctoc generated TOC please keep comment here to allow auto update -->
## Parsing
Roc's parsers are designed as [combinators](https://en.wikipedia.org/wiki/Parser_combinator).
A list of Roc's parse AST and combinators can be found in [the root parse
file](./parse/src/parser.rs).
Combinators enable parsing to compose as functions would - for example, the
`one_of` combinator supports attempting multiple parsing strategies, and
succeeding on the first one; the `and_then` combinator chains two parsers
together, failing if either parser in the sequence fails.
Since Roc is an indentation-sensitive language, parsing must be cognizant and
deligent about handling indentation and de-indentation levels. Most parsing
functions take a `min_indent` parameter that specifies the minimum indentation
of the scope an expression should be parsed in. Generally, failing to reach
`min_indent` indicates that an expression has ended (but perhaps too early).
## Canonicalization
After parsing a Roc program into an AST, the AST is transformed into a [canonical
form](./can/src/expr.rs) AST. This may seem a bit redundant - why build another
tree, when we already have the AST? Canonicalization performs a few analyses
to catch user errors, and sets up the state necessary to solve the types in a
program. Among other things, canonicalization
- Uniquely identifies names (think variable and function names). Along the way,
canonicalization builds a graph of all variables' references, and catches
unused definitions, undefined definitions, and shadowed definitions.
- Resolves type signatures, including aliases, into a form suitable for type
solving.
- Determines the order definitions are used in, if they are defined
out-of-order.
- Eliminates syntax sugar (for example, renaming `+` to the function call `add`
and converting backpassing to function calls).
- Collects declared abilities, and ability implementations defined for opaque
types. Derived abilities for opaque types are elaborated during
canonicalization.
### Symbol Resolution
Identifiers, like variable names, are resolved to [Symbol](./module/src/symbol.rs)s.
Currently, a symbol is a 64-bit value with
- the bottom 32 bits defining the [ModuleId](./module/src/ident.rs) the symbol
is defined in
- the top 32 bits defining the [IdentId](./module/src/ident.rs) of the symbol
in the module
A symbol is unique per identifier name and the scope
that the identifier has been declared in. Symbols are how the rest of the
compiler refers to value definitions - since the unique scope and identifier
name is disambiguated when symbols are created, referencing symbols requires no
further name resolution.
As symbols are constructed, canonicalization also keeps track of all references
to a given symbol. This simplifies catching unused definitions, undefined
definitions, and shadowing, to an index into an array.
### Type-alias normalization
### Closure naming
## Constraint Generation
### (Mutually-)recursive definitions
## Type Solving
### Unification
### Type Inference
### Recursive Types
### Lambda Sets
### Ability Collection
### Ability Specialization
### Ability Derivation
### Exhaustiveness Checking
### Debugging
## IR Generation
### Memory Layouts
### Compiling Calls
### Decision Trees
### Tail-call Optimization
### Reference-count insertion
### Reusing Memory Allocations
### Debugging
## LLVM Code Generator
### Morphic Analysis
### C ABI
### Test Harness
### Debugging
## WASM Code Generator
### WASM Interpreter
### Debugging
## Dev Code Generator
### Debugging
## Builtins
## Compiler Driver
### Caching types
## Repl
## `test` and `dbg`
## Formatter
## Glue
## Active areas of research / help wanted

View file

@ -1,147 +1,8 @@
# The Roc Compiler
Here's how the compiler is laid out.
## Parsing
The main goal of parsing is to take a plain old String (such as the contents a .roc source file read from the filesystem) and translate that String into an `Expr` value.
`Expr` is an `enum` defined in the `expr` module. An `Expr` represents a Roc expression.
For example, parsing would translate this string...
"1 + 2"
...into this `Expr` value:
BinOp(Int(1), Plus, Int(2))
> Technically it would be `Box::new(Int(1))` and `Box::new(Int(2))`, but that's beside the point for now.
This `Expr` representation of the expression is useful for things like:
- Checking that all variables are declared before they're used
- Type checking
> As of this writing, the compiler doesn't do any of those things yet. They'll be added later!
Since the parser is only concerned with translating String values into Expr values, it will happily translate syntactically valid strings into expressions that won't work at runtime.
For example, parsing will translate this string:
not "foo", "bar"
...into this `Expr`:
CallByName("not", vec!["foo", "bar"])
Now we may know that `not` takes a `Bool` and returns another `Bool`, but the parser doesn't know that.
The parser only knows how to translate a `String` into an `Expr`; it's the job of other parts of the compiler to figure out if `Expr` values have problems like type mismatches and non-exhaustive patterns.
That said, the parser can still run into syntax errors. This won't parse:
if then 5 then else then
This is gibberish to the parser, so it will produce an error rather than an `Expr`.
Roc's parser is implemented using the [`marwes/combine`](http://github.com/marwes/combine-language/) crate.
## Evaluating
One of the useful things we can do with an `Expr` is to evaluate it.
The process of evaluation is basically to transform an `Expr` into the simplest `Expr` we can that's still equivalent to the original.
For example, let's say we had this code:
"1 + 8 - 3"
The parser will translate this into the following `Expr`:
BinOp(
Int(1),
Plus,
BinOp(Int(8), Minus, Int(3))
)
The `eval` function will take this `Expr` and translate it into this much simpler `Expr`:
Int(6)
At this point it's become so simple that we can display it to the end user as the number `6`. So running `parse` and then `eval` on the original Roc string of `1 + 8 - 3` will result in displaying `6` as the final output.
> The `expr` module includes an `impl fmt::Display for Expr` that takes care of translating `Int(6)` into `6`, `Char('x')` as `'x'`, and so on.
`eval` accomplishes this by doing a `match` on an `Expr` and resolving every operation it encounters. For example, when it first sees this:
BinOp(
Int(1),
Plus,
BinOp(Int(8), Minus, Int(3))
)
The first thing it does is to call `eval` on the right `Expr` values on either side of the `Plus`. That results in:
1. Calling `eval` on `Int(1)`, which returns `Int(1)` since it can't be reduced any further.
2. Calling `eval` on `BinOp(Int(8), Minus, Int(3))`, which in fact can be reduced further.
Since the second call to `eval` will match on another `BinOp`, it's once again going to recursively call `eval` on both of its `Expr` values. Since those are both `Int` values, though, their `eval` calls will return them right away without doing anything else.
Now that it's evaluated the expressions on either side of the `Minus`, `eval` will look at the particular operator being applied to those expressions (in this case, a minus operator) and check to see if the expressions it was given work with that operation.
> Remember, this `Expr` value potentially came directly from the parser. `eval` can't be sure any type checking has been done on it!
If `eval` detects a non-numeric `Expr` value (that is, the `Expr` is not `Int` or `Frac`) on either side of the `Minus`, then it will immediately give an error and halt the evaluation. This sort of runtime type error is common to dynamic languages, and you can think of `eval` as being a dynamic evaluation of Roc code that hasn't necessarily been type-checked.
Assuming there's no type problem, `eval` can go ahead and run the Rust code of `8 - 3` and store the result in an `Int` expr.
That concludes our original recursive call to `eval`, after which point we'll be evaluating this expression:
BinOp(
Int(1),
Plus,
Int(5)
)
This will work the same way as `Minus` did, and will reduce down to `Int(6)`.
## Optimization philosophy
Focus on optimizations which are only safe in the absence of side effects, and leave the rest to LLVM.
This focus may lead to some optimizations becoming transitively in scope. For example, some deforestation
examples in the MSR paper benefit from multiple rounds of interleaved deforestation, beta-reduction, and inlining.
To get those benefits, we'd have to do some inlining and beta-reduction that we could otherwise leave to LLVM's
inlining and constant propagation/folding.
Even if we're doing those things, it may still make sense to have LLVM do a pass for them as well, since
early LLVM optimization passes may unlock later opportunities for inlining and constant propagation/folding.
## Inlining
If a function is called exactly once (it's a helper function), presumably we always want to inline those.
If a function is "small enough" it's probably worth inlining too.
## Fusion
<https://www.microsoft.com/en-us/research/wp-content/uploads/2016/07/deforestation-short-cut.pdf>
Basic approach:
Do list stuff using `build` passing Cons Nil (like a cons list) and then do foldr/build substitution/reduction.
Afterwards, we can do a separate pass to flatten nested Cons structures into properly initialized RRBTs.
This way we get both deforestation and efficient RRBT construction. Should work for the other collection types too.
It looks like we need to do some amount of inlining and beta reductions on the Roc side, rather than
leaving all of those to LLVM.
Advanced approach:
Express operations like map and filter in terms of toStream and fromStream, to unlock more deforestation.
More info on here:
<https://wiki.haskell.org/GHC_optimisations#Fusion>
For an overview of the design and architecture of the compiler, see
[DESIGN.md](./DESIGN.md). If you want to dive into the
implementation or get some tips on debugging the compiler, see below
## Getting started with the code

View file

@ -17,7 +17,7 @@ pub fn build(b: *Builder) void {
// Tests
var main_tests = b.addTest(main_path);
main_tests.setBuildMode(mode);
main_tests.linkSystemLibrary("c");
main_tests.linkLibC();
const test_step = b.step("test", "Run tests");
test_step.dependOn(&main_tests.step);

View file

@ -0,0 +1,442 @@
const std = @import("std");
const builtin = @import("builtin");
const math = std.math;
// Eventually, we need to statically ingest compiler-rt and get it working with the surgical linker, then these should not be needed anymore.
// Until then, we are manually ingesting used parts of compiler-rt here.
//
// Taken from
// https://github.com/ziglang/zig/tree/4976b58ab16069f8d3267b69ed030f29685c1abe/lib/compiler_rt/
// Thank you Zig Contributors!
// Libcalls that involve u128 on Windows x86-64 are expected by LLVM to use the
// calling convention of @Vector(2, u64), rather than what's standard.
pub const want_windows_v2u64_abi = builtin.os.tag == .windows and builtin.cpu.arch == .x86_64 and @import("builtin").object_format != .c;
const v2u64 = @Vector(2, u64);
// Export it as weak incase it is already linked in by something else.
comptime {
@export(__muloti4, .{ .name = "__muloti4", .linkage = .Weak });
if (want_windows_v2u64_abi) {
@export(__divti3_windows_x86_64, .{ .name = "__divti3", .linkage = .Weak });
@export(__modti3_windows_x86_64, .{ .name = "__modti3", .linkage = .Weak });
@export(__umodti3_windows_x86_64, .{ .name = "__umodti3", .linkage = .Weak });
@export(__udivti3_windows_x86_64, .{ .name = "__udivti3", .linkage = .Weak });
@export(__fixdfti_windows_x86_64, .{ .name = "__fixdfti", .linkage = .Weak });
@export(__fixsfti_windows_x86_64, .{ .name = "__fixsfti", .linkage = .Weak });
@export(__fixunsdfti_windows_x86_64, .{ .name = "__fixunsdfti", .linkage = .Weak });
@export(__fixunssfti_windows_x86_64, .{ .name = "__fixunssfti", .linkage = .Weak });
} else {
@export(__divti3, .{ .name = "__divti3", .linkage = .Weak });
@export(__modti3, .{ .name = "__modti3", .linkage = .Weak });
@export(__umodti3, .{ .name = "__umodti3", .linkage = .Weak });
@export(__udivti3, .{ .name = "__udivti3", .linkage = .Weak });
@export(__fixdfti, .{ .name = "__fixdfti", .linkage = .Weak });
@export(__fixsfti, .{ .name = "__fixsfti", .linkage = .Weak });
@export(__fixunsdfti, .{ .name = "__fixunsdfti", .linkage = .Weak });
@export(__fixunssfti, .{ .name = "__fixunssfti", .linkage = .Weak });
}
}
pub fn __muloti4(a: i128, b: i128, overflow: *c_int) callconv(.C) i128 {
if (2 * @bitSizeOf(i128) <= @bitSizeOf(usize)) {
return muloXi4_genericFast(i128, a, b, overflow);
} else {
return muloXi4_genericSmall(i128, a, b, overflow);
}
}
pub fn __divti3(a: i128, b: i128) callconv(.C) i128 {
return div(a, b);
}
fn __divti3_windows_x86_64(a: v2u64, b: v2u64) callconv(.C) v2u64 {
return @bitCast(v2u64, div(@bitCast(i128, a), @bitCast(i128, b)));
}
inline fn div(a: i128, b: i128) i128 {
const s_a = a >> (128 - 1);
const s_b = b >> (128 - 1);
const an = (a ^ s_a) -% s_a;
const bn = (b ^ s_b) -% s_b;
const r = udivmod(u128, @bitCast(u128, an), @bitCast(u128, bn), null);
const s = s_a ^ s_b;
return (@bitCast(i128, r) ^ s) -% s;
}
pub fn __udivti3(a: u128, b: u128) callconv(.C) u128 {
return udivmod(u128, a, b, null);
}
fn __udivti3_windows_x86_64(a: v2u64, b: v2u64) callconv(.C) v2u64 {
return @bitCast(v2u64, udivmod(u128, @bitCast(u128, a), @bitCast(u128, b), null));
}
pub fn __umodti3(a: u128, b: u128) callconv(.C) u128 {
var r: u128 = undefined;
_ = udivmod(u128, a, b, &r);
return r;
}
fn __umodti3_windows_x86_64(a: v2u64, b: v2u64) callconv(.C) v2u64 {
var r: u128 = undefined;
_ = udivmod(u128, @bitCast(u128, a), @bitCast(u128, b), &r);
return @bitCast(v2u64, r);
}
pub fn __modti3(a: i128, b: i128) callconv(.C) i128 {
return mod(a, b);
}
fn __modti3_windows_x86_64(a: v2u64, b: v2u64) callconv(.C) v2u64 {
return @bitCast(v2u64, mod(@bitCast(i128, a), @bitCast(i128, b)));
}
inline fn mod(a: i128, b: i128) i128 {
const s_a = a >> (128 - 1); // s = a < 0 ? -1 : 0
const s_b = b >> (128 - 1); // s = b < 0 ? -1 : 0
const an = (a ^ s_a) -% s_a; // negate if s == -1
const bn = (b ^ s_b) -% s_b; // negate if s == -1
var r: u128 = undefined;
_ = udivmod(u128, @bitCast(u128, an), @bitCast(u128, bn), &r);
return (@bitCast(i128, r) ^ s_a) -% s_a; // negate if s == -1
}
pub fn __fixdfti(a: f64) callconv(.C) i128 {
return floatToInt(i128, a);
}
fn __fixdfti_windows_x86_64(a: f64) callconv(.C) v2u64 {
return @bitCast(v2u64, floatToInt(i128, a));
}
pub fn __fixsfti(a: f32) callconv(.C) i128 {
return floatToInt(i128, a);
}
fn __fixsfti_windows_x86_64(a: f32) callconv(.C) v2u64 {
return @bitCast(v2u64, floatToInt(i128, a));
}
pub fn __fixunsdfti(a: f64) callconv(.C) u128 {
return floatToInt(u128, a);
}
fn __fixunsdfti_windows_x86_64(a: f64) callconv(.C) v2u64 {
return @bitCast(v2u64, floatToInt(u128, a));
}
pub fn __fixunssfti(a: f32) callconv(.C) u128 {
return floatToInt(u128, a);
}
fn __fixunssfti_windows_x86_64(a: f32) callconv(.C) v2u64 {
return @bitCast(v2u64, floatToInt(u128, a));
}
// mulo - multiplication overflow
// * return a*%b.
// * return if a*b overflows => 1 else => 0
// - muloXi4_genericSmall as default
// - muloXi4_genericFast for 2*bitsize <= usize
inline fn muloXi4_genericSmall(comptime ST: type, a: ST, b: ST, overflow: *c_int) ST {
overflow.* = 0;
const min = math.minInt(ST);
var res: ST = a *% b;
// Hacker's Delight section Overflow subsection Multiplication
// case a=-2^{31}, b=-1 problem, because
// on some machines a*b = -2^{31} with overflow
// Then -2^{31}/-1 overflows and any result is possible.
// => check with a<0 and b=-2^{31}
if ((a < 0 and b == min) or (a != 0 and @divTrunc(res, a) != b))
overflow.* = 1;
return res;
}
inline fn muloXi4_genericFast(comptime ST: type, a: ST, b: ST, overflow: *c_int) ST {
overflow.* = 0;
const EST = switch (ST) {
i32 => i64,
i64 => i128,
i128 => i256,
else => unreachable,
};
const min = math.minInt(ST);
const max = math.maxInt(ST);
var res: EST = @as(EST, a) * @as(EST, b);
//invariant: -2^{bitwidth(EST)} < res < 2^{bitwidth(EST)-1}
if (res < min or max < res)
overflow.* = 1;
return @truncate(ST, res);
}
const native_endian = builtin.cpu.arch.endian();
const low = switch (native_endian) {
.Big => 1,
.Little => 0,
};
const high = 1 - low;
pub fn udivmod(comptime DoubleInt: type, a: DoubleInt, b: DoubleInt, maybe_rem: ?*DoubleInt) DoubleInt {
// @setRuntimeSafety(builtin.is_test);
const double_int_bits = @typeInfo(DoubleInt).Int.bits;
const single_int_bits = @divExact(double_int_bits, 2);
const SingleInt = std.meta.Int(.unsigned, single_int_bits);
const SignedDoubleInt = std.meta.Int(.signed, double_int_bits);
const Log2SingleInt = std.math.Log2Int(SingleInt);
const n = @bitCast([2]SingleInt, a);
const d = @bitCast([2]SingleInt, b);
var q: [2]SingleInt = undefined;
var r: [2]SingleInt = undefined;
var sr: c_uint = undefined;
// special cases, X is unknown, K != 0
if (n[high] == 0) {
if (d[high] == 0) {
// 0 X
// ---
// 0 X
if (maybe_rem) |rem| {
rem.* = n[low] % d[low];
}
return n[low] / d[low];
}
// 0 X
// ---
// K X
if (maybe_rem) |rem| {
rem.* = n[low];
}
return 0;
}
// n[high] != 0
if (d[low] == 0) {
if (d[high] == 0) {
// K X
// ---
// 0 0
if (maybe_rem) |rem| {
rem.* = n[high] % d[low];
}
return n[high] / d[low];
}
// d[high] != 0
if (n[low] == 0) {
// K 0
// ---
// K 0
if (maybe_rem) |rem| {
r[high] = n[high] % d[high];
r[low] = 0;
rem.* = @bitCast(DoubleInt, r);
}
return n[high] / d[high];
}
// K K
// ---
// K 0
if ((d[high] & (d[high] - 1)) == 0) {
// d is a power of 2
if (maybe_rem) |rem| {
r[low] = n[low];
r[high] = n[high] & (d[high] - 1);
rem.* = @bitCast(DoubleInt, r);
}
return n[high] >> @intCast(Log2SingleInt, @ctz(SingleInt, d[high]));
}
// K K
// ---
// K 0
sr = @bitCast(c_uint, @as(c_int, @clz(SingleInt, d[high])) - @as(c_int, @clz(SingleInt, n[high])));
// 0 <= sr <= single_int_bits - 2 or sr large
if (sr > single_int_bits - 2) {
if (maybe_rem) |rem| {
rem.* = a;
}
return 0;
}
sr += 1;
// 1 <= sr <= single_int_bits - 1
// q.all = a << (double_int_bits - sr);
q[low] = 0;
q[high] = n[low] << @intCast(Log2SingleInt, single_int_bits - sr);
// r.all = a >> sr;
r[high] = n[high] >> @intCast(Log2SingleInt, sr);
r[low] = (n[high] << @intCast(Log2SingleInt, single_int_bits - sr)) | (n[low] >> @intCast(Log2SingleInt, sr));
} else {
// d[low] != 0
if (d[high] == 0) {
// K X
// ---
// 0 K
if ((d[low] & (d[low] - 1)) == 0) {
// d is a power of 2
if (maybe_rem) |rem| {
rem.* = n[low] & (d[low] - 1);
}
if (d[low] == 1) {
return a;
}
sr = @ctz(SingleInt, d[low]);
q[high] = n[high] >> @intCast(Log2SingleInt, sr);
q[low] = (n[high] << @intCast(Log2SingleInt, single_int_bits - sr)) | (n[low] >> @intCast(Log2SingleInt, sr));
return @bitCast(DoubleInt, q);
}
// K X
// ---
// 0 K
sr = 1 + single_int_bits + @as(c_uint, @clz(SingleInt, d[low])) - @as(c_uint, @clz(SingleInt, n[high]));
// 2 <= sr <= double_int_bits - 1
// q.all = a << (double_int_bits - sr);
// r.all = a >> sr;
if (sr == single_int_bits) {
q[low] = 0;
q[high] = n[low];
r[high] = 0;
r[low] = n[high];
} else if (sr < single_int_bits) {
// 2 <= sr <= single_int_bits - 1
q[low] = 0;
q[high] = n[low] << @intCast(Log2SingleInt, single_int_bits - sr);
r[high] = n[high] >> @intCast(Log2SingleInt, sr);
r[low] = (n[high] << @intCast(Log2SingleInt, single_int_bits - sr)) | (n[low] >> @intCast(Log2SingleInt, sr));
} else {
// single_int_bits + 1 <= sr <= double_int_bits - 1
q[low] = n[low] << @intCast(Log2SingleInt, double_int_bits - sr);
q[high] = (n[high] << @intCast(Log2SingleInt, double_int_bits - sr)) | (n[low] >> @intCast(Log2SingleInt, sr - single_int_bits));
r[high] = 0;
r[low] = n[high] >> @intCast(Log2SingleInt, sr - single_int_bits);
}
} else {
// K X
// ---
// K K
sr = @bitCast(c_uint, @as(c_int, @clz(SingleInt, d[high])) - @as(c_int, @clz(SingleInt, n[high])));
// 0 <= sr <= single_int_bits - 1 or sr large
if (sr > single_int_bits - 1) {
if (maybe_rem) |rem| {
rem.* = a;
}
return 0;
}
sr += 1;
// 1 <= sr <= single_int_bits
// q.all = a << (double_int_bits - sr);
// r.all = a >> sr;
q[low] = 0;
if (sr == single_int_bits) {
q[high] = n[low];
r[high] = 0;
r[low] = n[high];
} else {
r[high] = n[high] >> @intCast(Log2SingleInt, sr);
r[low] = (n[high] << @intCast(Log2SingleInt, single_int_bits - sr)) | (n[low] >> @intCast(Log2SingleInt, sr));
q[high] = n[low] << @intCast(Log2SingleInt, single_int_bits - sr);
}
}
}
// Not a special case
// q and r are initialized with:
// q.all = a << (double_int_bits - sr);
// r.all = a >> sr;
// 1 <= sr <= double_int_bits - 1
var carry: u32 = 0;
var r_all: DoubleInt = undefined;
while (sr > 0) : (sr -= 1) {
// r:q = ((r:q) << 1) | carry
r[high] = (r[high] << 1) | (r[low] >> (single_int_bits - 1));
r[low] = (r[low] << 1) | (q[high] >> (single_int_bits - 1));
q[high] = (q[high] << 1) | (q[low] >> (single_int_bits - 1));
q[low] = (q[low] << 1) | carry;
// carry = 0;
// if (r.all >= b)
// {
// r.all -= b;
// carry = 1;
// }
r_all = @bitCast(DoubleInt, r);
const s: SignedDoubleInt = @bitCast(SignedDoubleInt, b -% r_all -% 1) >> (double_int_bits - 1);
carry = @intCast(u32, s & 1);
r_all -= b & @bitCast(DoubleInt, s);
r = @bitCast([2]SingleInt, r_all);
}
const q_all = (@bitCast(DoubleInt, q) << 1) | carry;
if (maybe_rem) |rem| {
rem.* = r_all;
}
return q_all;
}
pub inline fn floatToInt(comptime I: type, a: anytype) I {
const Log2Int = math.Log2Int;
const Int = @import("std").meta.Int;
const F = @TypeOf(a);
const float_bits = @typeInfo(F).Float.bits;
const int_bits = @typeInfo(I).Int.bits;
const rep_t = Int(.unsigned, float_bits);
const sig_bits = math.floatMantissaBits(F);
const exp_bits = math.floatExponentBits(F);
const fractional_bits = floatFractionalBits(F);
// const implicit_bit = if (F != f80) (@as(rep_t, 1) << sig_bits) else 0;
const implicit_bit = @as(rep_t, 1) << sig_bits;
const max_exp = (1 << (exp_bits - 1));
const exp_bias = max_exp - 1;
const sig_mask = (@as(rep_t, 1) << sig_bits) - 1;
// Break a into sign, exponent, significand
const a_rep: rep_t = @bitCast(rep_t, a);
const negative = (a_rep >> (float_bits - 1)) != 0;
const exponent = @intCast(i32, (a_rep << 1) >> (sig_bits + 1)) - exp_bias;
const significand: rep_t = (a_rep & sig_mask) | implicit_bit;
// If the exponent is negative, the result rounds to zero.
if (exponent < 0) return 0;
// If the value is too large for the integer type, saturate.
switch (@typeInfo(I).Int.signedness) {
.unsigned => {
if (negative) return 0;
if (@intCast(c_uint, exponent) >= @minimum(int_bits, max_exp)) return math.maxInt(I);
},
.signed => if (@intCast(c_uint, exponent) >= @minimum(int_bits - 1, max_exp)) {
return if (negative) math.minInt(I) else math.maxInt(I);
},
}
// If 0 <= exponent < sig_bits, right shift to get the result.
// Otherwise, shift left.
var result: I = undefined;
if (exponent < fractional_bits) {
result = @intCast(I, significand >> @intCast(Log2Int(rep_t), fractional_bits - exponent));
} else {
result = @intCast(I, significand) << @intCast(Log2Int(I), exponent - fractional_bits);
}
if ((@typeInfo(I).Int.signedness == .signed) and negative)
return ~result +% 1;
return result;
}
/// Returns the number of fractional bits in the mantissa of floating point type T.
pub inline fn floatFractionalBits(comptime T: type) comptime_int {
comptime std.debug.assert(@typeInfo(T) == .Float);
// standard IEEE floats have an implicit 0.m or 1.m integer part
// f80 is special and has an explicitly stored bit in the MSB
// this function corresponds to `MANT_DIG - 1' from C
return switch (@typeInfo(T).Float.bits) {
16 => 10,
32 => 23,
64 => 52,
80 => 63,
128 => 112,
else => @compileError("unknown floating point type " ++ @typeName(T)),
};
}

View file

@ -5,6 +5,10 @@ const utils = @import("utils.zig");
const expect = @import("expect.zig");
const panic_utils = @import("panic.zig");
comptime {
_ = @import("compiler_rt.zig");
}
const ROC_BUILTINS = "roc_builtins";
const NUM = "num";
const STR = "str";
@ -81,8 +85,12 @@ comptime {
num.exportPow(T, ROC_BUILTINS ++ "." ++ NUM ++ ".pow_int.");
num.exportDivCeil(T, ROC_BUILTINS ++ "." ++ NUM ++ ".div_ceil.");
num.exportRoundF32(T, ROC_BUILTINS ++ "." ++ NUM ++ ".round_f32.");
num.exportRoundF64(T, ROC_BUILTINS ++ "." ++ NUM ++ ".round_f64.");
num.exportRound(f32, T, ROC_BUILTINS ++ "." ++ NUM ++ ".round_f32.");
num.exportRound(f64, T, ROC_BUILTINS ++ "." ++ NUM ++ ".round_f64.");
num.exportFloor(f32, T, ROC_BUILTINS ++ "." ++ NUM ++ ".floor_f32.");
num.exportFloor(f64, T, ROC_BUILTINS ++ "." ++ NUM ++ ".floor_f64.");
num.exportCeiling(f32, T, ROC_BUILTINS ++ "." ++ NUM ++ ".ceiling_f32.");
num.exportCeiling(f64, T, ROC_BUILTINS ++ "." ++ NUM ++ ".ceiling_f64.");
num.exportAddWithOverflow(T, ROC_BUILTINS ++ "." ++ NUM ++ ".add_with_overflow.");
num.exportAddOrPanic(T, ROC_BUILTINS ++ "." ++ NUM ++ ".add_or_panic.");
@ -122,6 +130,8 @@ comptime {
num.exportPow(T, ROC_BUILTINS ++ "." ++ NUM ++ ".pow.");
num.exportLog(T, ROC_BUILTINS ++ "." ++ NUM ++ ".log.");
num.exportFAbs(T, ROC_BUILTINS ++ "." ++ NUM ++ ".fabs.");
num.exportSqrt(T, ROC_BUILTINS ++ "." ++ NUM ++ ".sqrt.");
num.exportAddWithOverflow(T, ROC_BUILTINS ++ "." ++ NUM ++ ".add_with_overflow.");
num.exportSubWithOverflow(T, ROC_BUILTINS ++ "." ++ NUM ++ ".sub_with_overflow.");
@ -274,60 +284,3 @@ test "" {
testing.refAllDecls(@This());
}
// For unclear reasons, sometimes this function is not linked in on some machines.
// Therefore we provide it as LLVM bitcode and mark it as externally linked during our LLVM codegen
//
// Taken from
// https://github.com/ziglang/zig/blob/85755c51d529e7d9b406c6bdf69ce0a0f33f3353/lib/std/special/compiler_rt/muloti4.zig
//
// Thank you Zig Contributors!
// Export it as weak incase it is already linked in by something else.
comptime {
if (builtin.target.os.tag != .windows) {
@export(__muloti4, .{ .name = "__muloti4", .linkage = .Weak });
}
}
fn __muloti4(a: i128, b: i128, overflow: *c_int) callconv(.C) i128 {
// @setRuntimeSafety(std.builtin.is_test);
const min = @bitCast(i128, @as(u128, 1 << (128 - 1)));
const max = ~min;
overflow.* = 0;
const r = a *% b;
if (a == min) {
if (b != 0 and b != 1) {
overflow.* = 1;
}
return r;
}
if (b == min) {
if (a != 0 and a != 1) {
overflow.* = 1;
}
return r;
}
const sa = a >> (128 - 1);
const abs_a = (a ^ sa) -% sa;
const sb = b >> (128 - 1);
const abs_b = (b ^ sb) -% sb;
if (abs_a < 2 or abs_b < 2) {
return r;
}
if (sa == sb) {
if (abs_a > @divTrunc(max, abs_b)) {
overflow.* = 1;
}
} else {
if (abs_a > @divTrunc(min, -abs_b)) {
overflow.* = 1;
}
}
return r;
}

View file

@ -152,7 +152,7 @@ pub fn exportAtan(comptime T: type, comptime name: []const u8) void {
pub fn exportSin(comptime T: type, comptime name: []const u8) void {
comptime var f = struct {
fn func(input: T) callconv(.C) T {
return @sin(input);
return math.sin(input);
}
}.func;
@export(f, .{ .name = name ++ @typeName(T), .linkage = .Strong });
@ -161,7 +161,7 @@ pub fn exportSin(comptime T: type, comptime name: []const u8) void {
pub fn exportCos(comptime T: type, comptime name: []const u8) void {
comptime var f = struct {
fn func(input: T) callconv(.C) T {
return @cos(input);
return math.cos(input);
}
}.func;
@export(f, .{ .name = name ++ @typeName(T), .linkage = .Strong });
@ -170,25 +170,52 @@ pub fn exportCos(comptime T: type, comptime name: []const u8) void {
pub fn exportLog(comptime T: type, comptime name: []const u8) void {
comptime var f = struct {
fn func(input: T) callconv(.C) T {
return @log(input);
return math.ln(input);
}
}.func;
@export(f, .{ .name = name ++ @typeName(T), .linkage = .Strong });
}
pub fn exportRoundF32(comptime T: type, comptime name: []const u8) void {
pub fn exportFAbs(comptime T: type, comptime name: []const u8) void {
comptime var f = struct {
fn func(input: f32) callconv(.C) T {
return @floatToInt(T, (@round(input)));
fn func(input: T) callconv(.C) T {
return math.absFloat(input);
}
}.func;
@export(f, .{ .name = name ++ @typeName(T), .linkage = .Strong });
}
pub fn exportRoundF64(comptime T: type, comptime name: []const u8) void {
pub fn exportSqrt(comptime T: type, comptime name: []const u8) void {
comptime var f = struct {
fn func(input: f64) callconv(.C) T {
return @floatToInt(T, (@round(input)));
fn func(input: T) callconv(.C) T {
return math.sqrt(input);
}
}.func;
@export(f, .{ .name = name ++ @typeName(T), .linkage = .Strong });
}
pub fn exportRound(comptime F: type, comptime T: type, comptime name: []const u8) void {
comptime var f = struct {
fn func(input: F) callconv(.C) T {
return @floatToInt(T, (math.round(input)));
}
}.func;
@export(f, .{ .name = name ++ @typeName(T), .linkage = .Strong });
}
pub fn exportFloor(comptime F: type, comptime T: type, comptime name: []const u8) void {
comptime var f = struct {
fn func(input: F) callconv(.C) T {
return @floatToInt(T, (math.floor(input)));
}
}.func;
@export(f, .{ .name = name ++ @typeName(T), .linkage = .Strong });
}
pub fn exportCeiling(comptime F: type, comptime T: type, comptime name: []const u8) void {
comptime var f = struct {
fn func(input: F) callconv(.C) T {
return @floatToInt(T, (math.ceil(input)));
}
}.func;
@export(f, .{ .name = name ++ @typeName(T), .linkage = .Strong });

View file

@ -989,19 +989,11 @@ pow : Frac a, Frac a -> Frac a
## This process is known as [exponentiation by squaring](https://en.wikipedia.org/wiki/Exponentiation_by_squaring).
##
## For a [Frac] alternative to this function, which supports negative exponents,
## see #Num.exp.
## ```
## Num.exp 5 0
## see #Num.pow.
##
## Num.exp 5 1
## ## Warning
##
## Num.exp 5 2
##
## Num.exp 5 6
## ```
## ## Performance Details
##
## Be careful! It is very easy for this function to produce an answer
## It is very easy for this function to produce an answer
## so large it causes an overflow.
powInt : Int a, Int a -> Int a

View file

@ -267,9 +267,15 @@ pub const NUM_IS_INFINITE: IntrinsicName = float_intrinsic!("roc_builtins.num.is
pub const NUM_IS_FINITE: IntrinsicName = float_intrinsic!("roc_builtins.num.is_finite");
pub const NUM_LOG: IntrinsicName = float_intrinsic!("roc_builtins.num.log");
pub const NUM_POW: IntrinsicName = float_intrinsic!("roc_builtins.num.pow");
pub const NUM_FABS: IntrinsicName = float_intrinsic!("roc_builtins.num.fabs");
pub const NUM_SQRT: IntrinsicName = float_intrinsic!("roc_builtins.num.sqrt");
pub const NUM_POW_INT: IntrinsicName = int_intrinsic!("roc_builtins.num.pow_int");
pub const NUM_DIV_CEIL: IntrinsicName = int_intrinsic!("roc_builtins.num.div_ceil");
pub const NUM_CEILING_F32: IntrinsicName = int_intrinsic!("roc_builtins.num.ceiling_f32");
pub const NUM_CEILING_F64: IntrinsicName = int_intrinsic!("roc_builtins.num.ceiling_f64");
pub const NUM_FLOOR_F32: IntrinsicName = int_intrinsic!("roc_builtins.num.floor_f32");
pub const NUM_FLOOR_F64: IntrinsicName = int_intrinsic!("roc_builtins.num.floor_f64");
pub const NUM_ROUND_F32: IntrinsicName = int_intrinsic!("roc_builtins.num.round_f32");
pub const NUM_ROUND_F64: IntrinsicName = int_intrinsic!("roc_builtins.num.round_f64");

View file

@ -7,11 +7,12 @@ use inkwell::{
};
use roc_builtins::{
bitcode::{FloatWidth, IntWidth, IntrinsicName},
float_intrinsic, llvm_int_intrinsic,
llvm_int_intrinsic,
};
use super::build::{add_func, FunctionSpec};
#[allow(dead_code)]
fn add_float_intrinsic<'ctx, F>(
ctx: &'ctx Context,
module: &Module<'ctx>,
@ -111,18 +112,6 @@ pub(crate) fn add_intrinsics<'ctx>(ctx: &'ctx Context, module: &Module<'ctx>) {
i8_ptr_type.fn_type(&[], false),
);
add_float_intrinsic(ctx, module, &LLVM_LOG, |t| t.fn_type(&[t.into()], false));
add_float_intrinsic(ctx, module, &LLVM_POW, |t| {
t.fn_type(&[t.into(), t.into()], false)
});
add_float_intrinsic(ctx, module, &LLVM_FABS, |t| t.fn_type(&[t.into()], false));
add_float_intrinsic(ctx, module, &LLVM_SIN, |t| t.fn_type(&[t.into()], false));
add_float_intrinsic(ctx, module, &LLVM_COS, |t| t.fn_type(&[t.into()], false));
add_float_intrinsic(ctx, module, &LLVM_CEILING, |t| {
t.fn_type(&[t.into()], false)
});
add_float_intrinsic(ctx, module, &LLVM_FLOOR, |t| t.fn_type(&[t.into()], false));
add_int_intrinsic(ctx, module, &LLVM_ADD_WITH_OVERFLOW, |t| {
let fields = [t.into(), i1_type.into()];
ctx.struct_type(&fields, false)
@ -150,17 +139,6 @@ pub(crate) fn add_intrinsics<'ctx>(ctx: &'ctx Context, module: &Module<'ctx>) {
});
}
pub const LLVM_POW: IntrinsicName = float_intrinsic!("llvm.pow");
pub const LLVM_FABS: IntrinsicName = float_intrinsic!("llvm.fabs");
pub static LLVM_SQRT: IntrinsicName = float_intrinsic!("llvm.sqrt");
pub static LLVM_LOG: IntrinsicName = float_intrinsic!("llvm.log");
pub static LLVM_SIN: IntrinsicName = float_intrinsic!("llvm.sin");
pub static LLVM_COS: IntrinsicName = float_intrinsic!("llvm.cos");
pub static LLVM_CEILING: IntrinsicName = float_intrinsic!("llvm.ceil");
pub static LLVM_FLOOR: IntrinsicName = float_intrinsic!("llvm.floor");
pub static LLVM_ROUND: IntrinsicName = float_intrinsic!("llvm.round");
pub static LLVM_MEMSET_I64: &str = "llvm.memset.p0i8.i64";
pub static LLVM_MEMSET_I32: &str = "llvm.memset.p0i8.i32";

View file

@ -41,9 +41,13 @@ use crate::llvm::{
self, basic_type_from_layout, zig_num_parse_result_type, zig_to_int_checked_result_type,
},
intrinsics::{
LLVM_ADD_SATURATED, LLVM_ADD_WITH_OVERFLOW, LLVM_CEILING, LLVM_COS, LLVM_FABS, LLVM_FLOOR,
LLVM_LOG, LLVM_MUL_WITH_OVERFLOW, LLVM_POW, LLVM_ROUND, LLVM_SIN, LLVM_SQRT,
LLVM_SUB_SATURATED, LLVM_SUB_WITH_OVERFLOW,
// These instrinsics do not generate calls to libc and are safe to keep.
// If we find that any of them generate calls to libc on some platforms, we need to define them as zig bitcode.
LLVM_ADD_SATURATED,
LLVM_ADD_WITH_OVERFLOW,
LLVM_MUL_WITH_OVERFLOW,
LLVM_SUB_SATURATED,
LLVM_SUB_WITH_OVERFLOW,
},
refcounting::PointerToRefcount,
};
@ -1704,7 +1708,11 @@ fn build_float_binop<'ctx>(
NumLt => bd.build_float_compare(OLT, lhs, rhs, "float_lt").into(),
NumLte => bd.build_float_compare(OLE, lhs, rhs, "float_lte").into(),
NumDivFrac => bd.build_float_div(lhs, rhs, "div_float").into(),
NumPow => env.call_intrinsic(&LLVM_POW[float_width], &[lhs.into(), rhs.into()]),
NumPow => call_bitcode_fn(
env,
&[lhs.into(), rhs.into()],
&bitcode::NUM_POW[float_width],
),
_ => {
unreachable!("Unrecognized int binary operation: {:?}", op);
}
@ -2316,9 +2324,9 @@ fn build_float_unary_op<'a, 'ctx>(
// TODO: Handle different sized floats
match op {
NumNeg => bd.build_float_neg(arg, "negate_float").into(),
NumAbs => env.call_intrinsic(&LLVM_FABS[float_width], &[arg.into()]),
NumSqrtUnchecked => env.call_intrinsic(&LLVM_SQRT[float_width], &[arg.into()]),
NumLogUnchecked => env.call_intrinsic(&LLVM_LOG[float_width], &[arg.into()]),
NumAbs => call_bitcode_fn(env, &[arg.into()], &bitcode::NUM_FABS[float_width]),
NumSqrtUnchecked => call_bitcode_fn(env, &[arg.into()], &bitcode::NUM_SQRT[float_width]),
NumLogUnchecked => call_bitcode_fn(env, &[arg.into()], &bitcode::NUM_LOG[float_width]),
NumToFrac => {
let return_width = match layout_interner.get(layout).repr {
LayoutRepr::Builtin(Builtin::Float(return_width)) => return_width,
@ -2342,64 +2350,46 @@ fn build_float_unary_op<'a, 'ctx>(
}
}
NumCeiling => {
let (return_signed, return_type) = match layout_interner.get(layout).repr {
LayoutRepr::Builtin(Builtin::Int(int_width)) => (
int_width.is_signed(),
convert::int_type_from_int_width(env, int_width),
),
let int_width = match layout_interner.get(layout).repr {
LayoutRepr::Builtin(Builtin::Int(int_width)) => int_width,
_ => internal_error!("Ceiling return layout is not int: {:?}", layout),
};
let opcode = if return_signed {
InstructionOpcode::FPToSI
} else {
InstructionOpcode::FPToUI
};
env.builder.build_cast(
opcode,
env.call_intrinsic(&LLVM_CEILING[float_width], &[arg.into()]),
return_type,
"num_ceiling",
)
match float_width {
FloatWidth::F32 => {
call_bitcode_fn(env, &[arg.into()], &bitcode::NUM_CEILING_F32[int_width])
}
FloatWidth::F64 => {
call_bitcode_fn(env, &[arg.into()], &bitcode::NUM_CEILING_F64[int_width])
}
}
}
NumFloor => {
let (return_signed, return_type) = match layout_interner.get(layout).repr {
LayoutRepr::Builtin(Builtin::Int(int_width)) => (
int_width.is_signed(),
convert::int_type_from_int_width(env, int_width),
),
_ => internal_error!("Ceiling return layout is not int: {:?}", layout),
let int_width = match layout_interner.get(layout).repr {
LayoutRepr::Builtin(Builtin::Int(int_width)) => int_width,
_ => internal_error!("Floor return layout is not int: {:?}", layout),
};
let opcode = if return_signed {
InstructionOpcode::FPToSI
} else {
InstructionOpcode::FPToUI
};
env.builder.build_cast(
opcode,
env.call_intrinsic(&LLVM_FLOOR[float_width], &[arg.into()]),
return_type,
"num_floor",
)
match float_width {
FloatWidth::F32 => {
call_bitcode_fn(env, &[arg.into()], &bitcode::NUM_FLOOR_F32[int_width])
}
FloatWidth::F64 => {
call_bitcode_fn(env, &[arg.into()], &bitcode::NUM_FLOOR_F64[int_width])
}
}
}
NumRound => {
let (return_signed, return_type) = match layout_interner.get(layout).repr {
LayoutRepr::Builtin(Builtin::Int(int_width)) => (
int_width.is_signed(),
convert::int_type_from_int_width(env, int_width),
),
_ => internal_error!("Ceiling return layout is not int: {:?}", layout),
let int_width = match layout_interner.get(layout).repr {
LayoutRepr::Builtin(Builtin::Int(int_width)) => int_width,
_ => internal_error!("Round return layout is not int: {:?}", layout),
};
let opcode = if return_signed {
InstructionOpcode::FPToSI
} else {
InstructionOpcode::FPToUI
};
env.builder.build_cast(
opcode,
env.call_intrinsic(&LLVM_ROUND[float_width], &[arg.into()]),
return_type,
"num_round",
)
match float_width {
FloatWidth::F32 => {
call_bitcode_fn(env, &[arg.into()], &bitcode::NUM_ROUND_F32[int_width])
}
FloatWidth::F64 => {
call_bitcode_fn(env, &[arg.into()], &bitcode::NUM_ROUND_F64[int_width])
}
}
}
NumIsNan => call_bitcode_fn(env, &[arg.into()], &bitcode::NUM_IS_NAN[float_width]),
NumIsInfinite => {
@ -2408,8 +2398,8 @@ fn build_float_unary_op<'a, 'ctx>(
NumIsFinite => call_bitcode_fn(env, &[arg.into()], &bitcode::NUM_IS_FINITE[float_width]),
// trigonometry
NumSin => env.call_intrinsic(&LLVM_SIN[float_width], &[arg.into()]),
NumCos => env.call_intrinsic(&LLVM_COS[float_width], &[arg.into()]),
NumSin => call_bitcode_fn(env, &[arg.into()], &bitcode::NUM_SIN[float_width]),
NumCos => call_bitcode_fn(env, &[arg.into()], &bitcode::NUM_COS[float_width]),
NumAtan => call_bitcode_fn(env, &[arg.into()], &bitcode::NUM_ATAN[float_width]),
NumAcos => call_bitcode_fn(env, &[arg.into()], &bitcode::NUM_ACOS[float_width]),

View file

@ -32,7 +32,7 @@ use roc_module::symbol::{
};
use roc_mono::ir::{
CapturedSymbols, ExternalSpecializations, GlueLayouts, LambdaSetId, PartialProc, Proc,
ProcLayout, Procs, ProcsBase, UpdateModeIds,
ProcLayout, Procs, ProcsBase, UpdateModeIds, UsageTrackingMap,
};
use roc_mono::layout::LayoutInterner;
use roc_mono::layout::{
@ -5782,6 +5782,7 @@ fn make_specializations<'a>(
abilities: AbilitiesView::World(&world_abilities),
exposed_by_module,
derived_module: &derived_module,
struct_indexing: UsageTrackingMap::default(),
};
let mut procs = Procs::new_in(arena);
@ -5882,6 +5883,7 @@ fn build_pending_specializations<'a>(
abilities: AbilitiesView::Module(&abilities_store),
exposed_by_module,
derived_module: &derived_module,
struct_indexing: UsageTrackingMap::default(),
};
let layout_cache_snapshot = layout_cache.snapshot();
@ -6363,6 +6365,7 @@ fn load_derived_partial_procs<'a>(
abilities: AbilitiesView::World(world_abilities),
exposed_by_module,
derived_module,
struct_indexing: UsageTrackingMap::default(),
};
let partial_proc = match derived_expr {

View file

@ -1386,6 +1386,7 @@ pub struct Env<'a, 'i> {
pub abilities: AbilitiesView<'i>,
pub exposed_by_module: &'i ExposedByModule,
pub derived_module: &'i SharedDerivedModule,
pub struct_indexing: UsageTrackingMap<(Symbol, u64), Symbol>,
}
impl<'a, 'i> Env<'a, 'i> {
@ -4225,6 +4226,7 @@ pub fn with_hole<'a>(
// If this symbol is a raw value, find the real name we gave to its specialized usage.
if let ReuseSymbol::Value(_symbol) = can_reuse_symbol(
env,
layout_cache,
procs,
&roc_can::expr::Expr::Var(symbol, variable),
variable,
@ -4324,7 +4326,7 @@ pub fn with_hole<'a>(
OpaqueRef { argument, .. } => {
let (arg_var, loc_arg_expr) = *argument;
match can_reuse_symbol(env, procs, &loc_arg_expr.value, arg_var) {
match can_reuse_symbol(env, layout_cache, procs, &loc_arg_expr.value, arg_var) {
// Opaques decay to their argument.
ReuseSymbol::Value(symbol) => {
let real_name = procs.get_or_insert_symbol_specialization(
@ -4909,20 +4911,19 @@ pub fn with_hole<'a>(
RecordUpdate {
record_var,
symbol: structure,
updates,
ref updates,
..
} => {
use FieldType::*;
enum FieldType<'a> {
CopyExisting(u64),
CopyExisting,
UpdateExisting(&'a roc_can::expr::Field),
}
// Strategy: turn a record update into the creation of a new record.
// This has the benefit that we don't need to do anything special for reference
// counting
let sorted_fields_result = {
let mut layout_env = layout::Env::from_components(
layout_cache,
@ -4938,43 +4939,56 @@ pub fn with_hole<'a>(
Err(_) => return runtime_error(env, "Can't update record with improper layout"),
};
let mut field_layouts = Vec::with_capacity_in(sorted_fields.len(), env.arena);
let single_field_struct = sorted_fields.len() == 1;
let mut symbols = Vec::with_capacity_in(sorted_fields.len(), env.arena);
// The struct indexing generated by the current context
let mut current_struct_indexing = Vec::with_capacity_in(sorted_fields.len(), env.arena);
// The symbols that are used to create the new struct
let mut new_struct_symbols = Vec::with_capacity_in(sorted_fields.len(), env.arena);
// Information about the fields that are being updated
let mut fields = Vec::with_capacity_in(sorted_fields.len(), env.arena);
let mut index = 0;
for (label, _, opt_field_layout) in sorted_fields.iter() {
let record_index = (structure, index);
let mut current = 0;
for (label, _, opt_field_layout) in sorted_fields.into_iter() {
match opt_field_layout {
Err(_) => {
debug_assert!(!updates.contains_key(&label));
debug_assert!(!updates.contains_key(label));
// this was an optional field, and now does not exist!
// do not increment `current`!
// do not increment `index`!
}
Ok(field_layout) => {
field_layouts.push(field_layout);
Ok(_field_layout) => {
current_struct_indexing.push(record_index);
if let Some(field) = updates.get(&label) {
let field_symbol = possible_reuse_symbol_or_specialize(
// The struct with a single field is optimized in such a way that replacing later indexing will cause an incorrect IR.
// Thus, only insert these struct_indices if there is more than one field in the struct.
if !single_field_struct {
let original_struct_symbol = env.unique_symbol();
env.struct_indexing
.insert(record_index, original_struct_symbol);
}
if let Some(field) = updates.get(label) {
let new_struct_symbol = possible_reuse_symbol_or_specialize(
env,
procs,
layout_cache,
&field.loc_expr.value,
field.var,
);
new_struct_symbols.push(new_struct_symbol);
fields.push(UpdateExisting(field));
symbols.push(field_symbol);
} else {
fields.push(CopyExisting(current));
symbols.push(env.unique_symbol());
new_struct_symbols
.push(*env.struct_indexing.get(record_index).unwrap());
fields.push(CopyExisting);
}
current += 1;
index += 1;
}
}
}
let symbols = symbols.into_bump_slice();
let new_struct_symbols = new_struct_symbols.into_bump_slice();
let record_layout = layout_cache
.from_var(env.arena, record_var, env.subs)
@ -4985,8 +4999,8 @@ pub fn with_hole<'a>(
_ => arena.alloc([record_layout]),
};
if symbols.len() == 1 {
// TODO we can probably special-case this more, skippiing the generation of
if single_field_struct {
// TODO we can probably special-case this more, skipping the generation of
// UpdateExisting
let mut stmt = hole.clone();
@ -4994,7 +5008,7 @@ pub fn with_hole<'a>(
match what_to_do {
UpdateExisting(field) => {
substitute_in_exprs(env.arena, &mut stmt, assigned, symbols[0]);
substitute_in_exprs(env.arena, &mut stmt, assigned, new_struct_symbols[0]);
stmt = assign_to_symbol(
env,
@ -5002,11 +5016,11 @@ pub fn with_hole<'a>(
layout_cache,
field.var,
*field.loc_expr.clone(),
symbols[0],
new_struct_symbols[0],
stmt,
);
}
CopyExisting(_) => {
CopyExisting => {
unreachable!(
r"when a record has just one field and is updated, it must update that one field"
);
@ -5015,12 +5029,10 @@ pub fn with_hole<'a>(
stmt
} else {
let expr = Expr::Struct(symbols);
let expr = Expr::Struct(new_struct_symbols);
let mut stmt = Stmt::Let(assigned, expr, record_layout, hole);
let it = field_layouts.iter().zip(symbols.iter()).zip(fields);
for ((field_layout, symbol), what_to_do) in it {
for (new_struct_symbol, what_to_do) in new_struct_symbols.iter().zip(fields) {
match what_to_do {
UpdateExisting(field) => {
stmt = assign_to_symbol(
@ -5029,47 +5041,54 @@ pub fn with_hole<'a>(
layout_cache,
field.var,
*field.loc_expr.clone(),
*symbol,
*new_struct_symbol,
stmt,
);
}
CopyExisting(index) => {
let structure_needs_specialization =
procs.ability_member_aliases.get(structure).is_some()
|| procs.is_module_thunk(structure)
|| procs.is_imported_module_thunk(structure);
let specialized_structure_sym = if structure_needs_specialization {
// We need to specialize the record now; create a new one for it.
// TODO: reuse this symbol for all updates
env.unique_symbol()
} else {
// The record is already good.
structure
};
let access_expr = Expr::StructAtIndex {
structure: specialized_structure_sym,
index,
field_layouts,
};
stmt =
Stmt::Let(*symbol, access_expr, *field_layout, arena.alloc(stmt));
if structure_needs_specialization {
stmt = specialize_symbol(
env,
procs,
layout_cache,
Some(record_var),
specialized_structure_sym,
env.arena.alloc(stmt),
structure,
);
}
CopyExisting => {
// When a field is copied, the indexing symbol is already placed in new_struct_symbols
// Thus, we don't need additional logic here.
}
}
}
let structure_needs_specialization =
procs.ability_member_aliases.get(structure).is_some()
|| procs.is_module_thunk(structure)
|| procs.is_imported_module_thunk(structure);
let specialized_structure_sym = if structure_needs_specialization {
// We need to specialize the record now; create a new one for it.
env.unique_symbol()
} else {
// The record is already good.
structure
};
for record_index in current_struct_indexing.into_iter().rev() {
if let Some(symbol) = env.struct_indexing.get_used(&record_index) {
let layout = field_layouts[record_index.1 as usize];
let access_expr = Expr::StructAtIndex {
structure: specialized_structure_sym,
index: record_index.1,
field_layouts,
};
stmt = Stmt::Let(symbol, access_expr, layout, arena.alloc(stmt));
};
}
if structure_needs_specialization {
stmt = specialize_symbol(
env,
procs,
layout_cache,
Some(record_var),
specialized_structure_sym,
env.arena.alloc(stmt),
structure,
);
}
stmt
}
}
@ -5227,7 +5246,7 @@ pub fn with_hole<'a>(
// re-use that symbol, and don't define its value again
let mut result;
use ReuseSymbol::*;
match can_reuse_symbol(env, procs, &loc_expr.value, fn_var) {
match can_reuse_symbol(env, layout_cache, procs, &loc_expr.value, fn_var) {
LocalFunction(_) => {
unreachable!("if this was known to be a function, we would not be here")
}
@ -5685,22 +5704,28 @@ fn compile_struct_like<'a, L, UnusedLayout>(
// TODO how should function pointers be handled here?
use ReuseSymbol::*;
match take_elem_expr(index) {
Some((var, loc_expr)) => match can_reuse_symbol(env, procs, &loc_expr.value, var) {
Imported(symbol) | LocalFunction(symbol) | UnspecializedExpr(symbol) => {
elem_symbols.push(symbol);
can_elems.push(Field::FunctionOrUnspecialized(symbol, variable));
Some((var, loc_expr)) => {
match can_reuse_symbol(env, layout_cache, procs, &loc_expr.value, var) {
Imported(symbol) | LocalFunction(symbol) | UnspecializedExpr(symbol) => {
elem_symbols.push(symbol);
can_elems.push(Field::FunctionOrUnspecialized(symbol, variable));
}
Value(symbol) => {
let reusable = procs.get_or_insert_symbol_specialization(
env,
layout_cache,
symbol,
var,
);
elem_symbols.push(reusable);
can_elems.push(Field::ValueSymbol);
}
NotASymbol => {
elem_symbols.push(env.unique_symbol());
can_elems.push(Field::Field(var, *loc_expr));
}
}
Value(symbol) => {
let reusable =
procs.get_or_insert_symbol_specialization(env, layout_cache, symbol, var);
elem_symbols.push(reusable);
can_elems.push(Field::ValueSymbol);
}
NotASymbol => {
elem_symbols.push(env.unique_symbol());
can_elems.push(Field::Field(var, *loc_expr));
}
},
}
None => {
// this field was optional, but not given
continue;
@ -6816,7 +6841,7 @@ pub fn from_can<'a>(
store_specialized_expectation_lookups(env, [variable], &[spec_var]);
let symbol_is_reused = matches!(
can_reuse_symbol(env, procs, &loc_condition.value, variable),
can_reuse_symbol(env, layout_cache, procs, &loc_condition.value, variable),
ReuseSymbol::Value(_)
);
@ -7615,7 +7640,8 @@ enum ReuseSymbol {
fn can_reuse_symbol<'a>(
env: &mut Env<'a, '_>,
procs: &Procs<'a>,
layout_cache: &mut LayoutCache<'a>,
procs: &mut Procs<'a>,
expr: &roc_can::expr::Expr,
expr_var: Variable,
) -> ReuseSymbol {
@ -7627,6 +7653,52 @@ fn can_reuse_symbol<'a>(
late_resolve_ability_specialization(env, *member, *specialization_id, expr_var)
}
Var(symbol, _) => *symbol,
RecordAccess {
record_var,
field,
loc_expr,
..
} => {
let sorted_fields_result = {
let mut layout_env = layout::Env::from_components(
layout_cache,
env.subs,
env.arena,
env.target_info,
);
layout::sort_record_fields(&mut layout_env, *record_var)
};
let sorted_fields = match sorted_fields_result {
Ok(fields) => fields,
Err(_) => unreachable!("Can't access record with improper layout"),
};
let index = sorted_fields
.into_iter()
.enumerate()
.find_map(|(current, (label, _, _))| (label == *field).then_some(current));
let struct_index = index.expect("field not in its own type");
let struct_symbol = possible_reuse_symbol_or_specialize(
env,
procs,
layout_cache,
&loc_expr.value,
*record_var,
);
match env
.struct_indexing
.get((struct_symbol, struct_index as u64))
{
Some(symbol) => *symbol,
None => {
return NotASymbol;
}
}
}
_ => return NotASymbol,
};
@ -7660,7 +7732,7 @@ fn possible_reuse_symbol_or_specialize<'a>(
expr: &roc_can::expr::Expr,
var: Variable,
) -> Symbol {
match can_reuse_symbol(env, procs, expr, var) {
match can_reuse_symbol(env, layout_cache, procs, expr, var) {
ReuseSymbol::Value(symbol) => {
procs.get_or_insert_symbol_specialization(env, layout_cache, symbol, var)
}
@ -7999,7 +8071,7 @@ fn assign_to_symbol<'a>(
result: Stmt<'a>,
) -> Stmt<'a> {
use ReuseSymbol::*;
match can_reuse_symbol(env, procs, &loc_arg.value, arg_var) {
match can_reuse_symbol(env, layout_cache, procs, &loc_arg.value, arg_var) {
Imported(original) | LocalFunction(original) | UnspecializedExpr(original) => {
// for functions we must make sure they are specialized correctly
specialize_symbol(
@ -9983,3 +10055,42 @@ where
answer
}
enum Usage {
Used,
Unused,
}
pub struct UsageTrackingMap<K, V> {
map: MutMap<K, (V, Usage)>,
}
impl<K, V> Default for UsageTrackingMap<K, V> {
fn default() -> Self {
Self {
map: MutMap::default(),
}
}
}
impl<K, V> UsageTrackingMap<K, V>
where
K: std::cmp::Eq + std::hash::Hash,
{
pub fn insert(&mut self, key: K, value: V) {
self.map.insert(key, (value, Usage::Unused));
}
pub fn get(&mut self, key: K) -> Option<&V> {
let (value, usage) = self.map.get_mut(&key)?;
*usage = Usage::Used;
Some(value)
}
fn get_used(&mut self, key: &K) -> Option<V> {
self.map.remove(key).and_then(|(value, usage)| match usage {
Usage::Used => Some(value),
Usage::Unused => None,
})
}
}

View file

@ -1842,7 +1842,7 @@ fn float_add_checked_fail() {
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen_dev"))]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))]
fn float_add_overflow() {
assert_evals_to!(
"1.7976931348623157e308 + 1.7976931348623157e308",

View file

@ -1,13 +1,13 @@
procedure Test.1 ():
let Test.7 : U8 = 1i64;
let Test.8 : U8 = 2i64;
let Test.6 : {U8, U8} = Struct {Test.7, Test.8};
ret Test.6;
let Test.10 : U8 = 1i64;
let Test.11 : U8 = 2i64;
let Test.9 : {U8, U8} = Struct {Test.10, Test.11};
ret Test.9;
procedure Test.0 ():
let Test.9 : {U8, U8} = CallByName Test.1;
let Test.3 : U8 = StructAtIndex 0 Test.9;
let Test.5 : {U8, U8} = CallByName Test.1;
let Test.4 : U8 = StructAtIndex 1 Test.5;
let Test.2 : List U8 = Array [Test.3, Test.4];
let Test.13 : {U8, U8} = CallByName Test.1;
let Test.4 : U8 = StructAtIndex 0 Test.13;
let Test.8 : {U8, U8} = CallByName Test.1;
let Test.6 : U8 = StructAtIndex 1 Test.8;
let Test.2 : List U8 = Array [Test.4, Test.6];
ret Test.2;

View file

@ -0,0 +1,47 @@
procedure List.3 (List.104, List.105, List.106):
let List.503 : {List U64, U64} = CallByName List.64 List.104 List.105 List.106;
let List.502 : List U64 = StructAtIndex 0 List.503;
ret List.502;
procedure List.6 (#Attr.2):
let List.501 : U64 = lowlevel ListLen #Attr.2;
ret List.501;
procedure List.64 (List.101, List.102, List.103):
let List.500 : U64 = CallByName List.6 List.101;
let List.497 : Int1 = CallByName Num.22 List.102 List.500;
if List.497 then
let List.498 : {List U64, U64} = CallByName List.67 List.101 List.102 List.103;
ret List.498;
else
let List.496 : {List U64, U64} = Struct {List.101, List.103};
ret List.496;
procedure List.67 (#Attr.2, #Attr.3, #Attr.4):
let List.499 : {List U64, U64} = lowlevel ListReplaceUnsafe #Attr.2 #Attr.3 #Attr.4;
ret List.499;
procedure Num.22 (#Attr.2, #Attr.3):
let Num.281 : Int1 = lowlevel NumLt #Attr.2 #Attr.3;
ret Num.281;
procedure Test.1 (Test.2):
let Test.6 : List U64 = StructAtIndex 0 Test.2;
let Test.8 : List U64 = StructAtIndex 1 Test.2;
let Test.10 : List U64 = StructAtIndex 2 Test.2;
let Test.13 : U64 = 8i64;
let Test.14 : U64 = 8i64;
let Test.9 : List U64 = CallByName List.3 Test.8 Test.13 Test.14;
let Test.11 : U64 = 7i64;
let Test.12 : U64 = 7i64;
let Test.7 : List U64 = CallByName List.3 Test.6 Test.11 Test.12;
let Test.5 : {List U64, List U64, List U64} = Struct {Test.7, Test.9, Test.10};
ret Test.5;
procedure Test.0 ():
let Test.15 : List U64 = Array [];
let Test.16 : List U64 = Array [];
let Test.17 : List U64 = Array [];
let Test.4 : {List U64, List U64, List U64} = Struct {Test.15, Test.16, Test.17};
let Test.3 : {List U64, List U64, List U64} = CallByName Test.1 Test.4;
ret Test.3;

View file

@ -3066,3 +3066,15 @@ fn drop_specialize_after_struct() {
"#
)
}
#[mono_test]
fn record_update() {
indoc!(
r#"
app "test" provides [main] to "./platform"
main = f {a: [], b: [], c:[]}
f : {a: List Nat, b: List Nat, c: List Nat} -> {a: List Nat, b: List Nat, c: List Nat}
f = \record -> {record & a: List.set record.a 7 7, b: List.set record.b 8 8}
"#
)
}

View file

@ -23,10 +23,10 @@ procedure Dep.0 ():
ret Dep.1;
procedure Test.0 ():
let Test.3 : Str = "http://www.example.com";
let Test.4 : {Str, Str} = CallByName Dep.0;
let Test.2 : Str = StructAtIndex 0 Test.4;
let #Derived_gen.0 : Str = StructAtIndex 1 Test.4;
let Test.5 : {Str, Str} = CallByName Dep.0;
let Test.2 : Str = StructAtIndex 0 Test.5;
let #Derived_gen.0 : Str = StructAtIndex 1 Test.5;
dec #Derived_gen.0;
let Test.1 : {Str, Str} = Struct {Test.2, Test.3};
let Test.4 : Str = "http://www.example.com";
let Test.1 : {Str, Str} = Struct {Test.2, Test.4};
ret Test.1;