Merge pull request #1013 from rtfeldman/astar-fixes

Astar fixes
This commit is contained in:
Richard Feldman 2021-02-20 23:36:12 -05:00 committed by GitHub
commit 2f479e9ffb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
28 changed files with 649 additions and 326 deletions

View file

@ -86,6 +86,7 @@ test-rust:
FROM +build-rust
ARG RUSTC_WRAPPER=/usr/local/cargo/bin/sccache
ARG SCCACHE_DIR=/earthbuild/sccache_dir
ARG RUST_BACKTRACE=1
RUN cargo test --release
test-all:
@ -94,4 +95,4 @@ test-all:
BUILD +save-cache
BUILD +test-zig
BUILD +test-rust

View file

@ -228,7 +228,6 @@ mod cli_run {
}
#[test]
#[ignore]
#[serial(astar)]
fn run_astar_optimized_1() {
check_output_with_stdin(
@ -279,11 +278,11 @@ mod cli_run {
}
#[test]
#[serial(closure4)]
fn closure4() {
#[serial(closure)]
fn closure() {
check_output(
&example_file("benchmarks", "Closure4.roc"),
"closure4",
&example_file("benchmarks", "Closure.roc"),
"closure",
&[],
"",
true,

View file

@ -296,6 +296,10 @@ pub const RocDict = extern struct {
}
fn getKey(self: *const RocDict, index: usize, alignment: Alignment, key_width: usize, value_width: usize) Opaque {
if (key_width == 0) {
return null;
}
const offset = blk: {
if (alignment.keyFirst()) {
break :blk (index * key_width);
@ -329,6 +333,10 @@ pub const RocDict = extern struct {
}
fn getValue(self: *const RocDict, index: usize, alignment: Alignment, key_width: usize, value_width: usize) Opaque {
if (value_width == 0) {
return null;
}
const offset = blk: {
if (alignment.keyFirst()) {
break :blk (self.capacity() * key_width) + (index * value_width);
@ -492,6 +500,8 @@ pub fn dictRemove(input: RocDict, alignment: Alignment, key: Opaque, key_width:
if (dict.dict_entries_len == 0) {
const data_bytes = dict.capacity() * slotSize(key_width, value_width);
decref(std.heap.c_allocator, alignment, dict.dict_bytes, data_bytes);
output.* = RocDict.empty();
return;
}
output.* = dict;
@ -660,7 +670,9 @@ pub fn dictUnion(dict1: RocDict, dict2: RocDict, alignment: Alignment, key_width
// we need an extra RC token for the key
inc_key(key);
inc_value(value);
// we know the newly added key is not a duplicate, so the `dec`s are unreachable
const dec_key = doNothing;
const dec_value = doNothing;
@ -702,7 +714,7 @@ pub fn dictIntersection(dict1: RocDict, dict2: RocDict, alignment: Alignment, ke
}
}
pub fn dictDifference(dict1: RocDict, dict2: RocDict, alignment: Alignment, key_width: usize, value_width: usize, hash_fn: HashFn, is_eq: EqFn, dec_key: Inc, dec_value: Inc, output: *RocDict) callconv(.C) void {
pub fn dictDifference(dict1: RocDict, dict2: RocDict, alignment: Alignment, key_width: usize, value_width: usize, hash_fn: HashFn, is_eq: EqFn, dec_key: Dec, dec_value: Dec, output: *RocDict) callconv(.C) void {
output.* = dict1.makeUnique(std.heap.c_allocator, alignment, key_width, value_width);
var i: usize = 0;
@ -748,8 +760,14 @@ pub fn setFromList(list: RocList, alignment: Alignment, key_width: usize, value_
}
const StepperCaller = fn (?[*]u8, ?[*]u8, ?[*]u8, ?[*]u8, ?[*]u8) callconv(.C) void;
pub fn dictWalk(dict: RocDict, stepper: Opaque, stepper_caller: StepperCaller, accum: Opaque, alignment: Alignment, key_width: usize, value_width: usize, accum_width: usize, output: Opaque) callconv(.C) void {
@memcpy(output orelse unreachable, accum orelse unreachable, accum_width);
pub fn dictWalk(dict: RocDict, stepper: Opaque, stepper_caller: StepperCaller, accum: Opaque, alignment: Alignment, key_width: usize, value_width: usize, accum_width: usize, inc_key: Inc, inc_value: Inc, output: Opaque) callconv(.C) void {
// allocate space to write the result of the stepper into
// experimentally aliasing the accum and output pointers is not a good idea
const alloc: [*]u8 = @ptrCast([*]u8, std.heap.c_allocator.alloc(u8, accum_width) catch unreachable);
var b1 = output orelse unreachable;
var b2 = alloc;
@memcpy(b2, accum orelse unreachable, accum_width);
var i: usize = 0;
const size = dict.capacity();
@ -759,12 +777,19 @@ pub fn dictWalk(dict: RocDict, stepper: Opaque, stepper_caller: StepperCaller, a
const key = dict.getKey(i, alignment, key_width, value_width);
const value = dict.getValue(i, alignment, key_width, value_width);
stepper_caller(stepper, key, value, output, output);
stepper_caller(stepper, key, value, b2, b1);
const temp = b1;
b2 = b1;
b1 = temp;
},
else => {},
}
}
@memcpy(output orelse unreachable, b2, accum_width);
std.heap.c_allocator.free(alloc[0..accum_width]);
const data_bytes = dict.capacity() * slotSize(key_width, value_width);
decref(std.heap.c_allocator, alignment, dict.dict_bytes, data_bytes);
}

View file

@ -7,6 +7,9 @@ const Allocator = mem.Allocator;
const EqFn = fn (?[*]u8, ?[*]u8) callconv(.C) bool;
const Opaque = ?[*]u8;
const Inc = fn (?[*]u8) callconv(.C) void;
const Dec = fn (?[*]u8) callconv(.C) void;
pub const RocList = extern struct {
bytes: ?[*]u8,
length: usize,
@ -58,7 +61,7 @@ pub const RocList = extern struct {
}
// unfortunately, we have to clone
var new_list = RocList.allocate(allocator, self.length, alignment, element_width);
var new_list = RocList.allocate(allocator, alignment, self.length, element_width);
var old_bytes: [*]u8 = @ptrCast([*]u8, self.bytes);
var new_bytes: [*]u8 = @ptrCast([*]u8, new_list.bytes);
@ -149,7 +152,7 @@ pub fn listMapWithIndex(list: RocList, transform: Opaque, caller: Caller2, align
}
}
pub fn listKeepIf(list: RocList, transform: Opaque, caller: Caller1, alignment: usize, element_width: usize) callconv(.C) RocList {
pub fn listKeepIf(list: RocList, transform: Opaque, caller: Caller1, alignment: usize, element_width: usize, inc: Inc, dec: Dec) callconv(.C) RocList {
if (list.bytes) |source_ptr| {
const size = list.len();
var i: usize = 0;
@ -160,6 +163,7 @@ pub fn listKeepIf(list: RocList, transform: Opaque, caller: Caller1, alignment:
while (i < size) : (i += 1) {
var keep = false;
const element = source_ptr + (i * element_width);
inc(element);
caller(transform, element, @ptrCast(?[*]u8, &keep));
if (keep) {
@ -167,29 +171,36 @@ pub fn listKeepIf(list: RocList, transform: Opaque, caller: Caller1, alignment:
kept += 1;
} else {
// TODO decrement the value?
dec(element);
}
}
output.length = kept;
// consume the input list
utils.decref(std.heap.c_allocator, alignment, list.bytes, size * element_width);
return output;
if (kept == 0) {
// if the output is empty, deallocate the space we made for the result
utils.decref(std.heap.c_allocator, alignment, output.bytes, size * element_width);
return RocList.empty();
} else {
output.length = kept;
return output;
}
} else {
return RocList.empty();
}
}
pub fn listKeepOks(list: RocList, transform: Opaque, caller: Caller1, alignment: usize, before_width: usize, result_width: usize, after_width: usize) callconv(.C) RocList {
return listKeepResult(list, RocResult.isOk, transform, caller, alignment, before_width, result_width, after_width);
pub fn listKeepOks(list: RocList, transform: Opaque, caller: Caller1, alignment: usize, before_width: usize, result_width: usize, after_width: usize, inc_closure: Inc, dec_result: Dec) callconv(.C) RocList {
return listKeepResult(list, RocResult.isOk, transform, caller, alignment, before_width, result_width, after_width, inc_closure, dec_result);
}
pub fn listKeepErrs(list: RocList, transform: Opaque, caller: Caller1, alignment: usize, before_width: usize, result_width: usize, after_width: usize) callconv(.C) RocList {
return listKeepResult(list, RocResult.isErr, transform, caller, alignment, before_width, result_width, after_width);
pub fn listKeepErrs(list: RocList, transform: Opaque, caller: Caller1, alignment: usize, before_width: usize, result_width: usize, after_width: usize, inc_closure: Inc, dec_result: Dec) callconv(.C) RocList {
return listKeepResult(list, RocResult.isErr, transform, caller, alignment, before_width, result_width, after_width, inc_closure, dec_result);
}
pub fn listKeepResult(list: RocList, is_good_constructor: fn (RocResult) bool, transform: Opaque, caller: Caller1, alignment: usize, before_width: usize, result_width: usize, after_width: usize) RocList {
pub fn listKeepResult(list: RocList, is_good_constructor: fn (RocResult) bool, transform: Opaque, caller: Caller1, alignment: usize, before_width: usize, result_width: usize, after_width: usize, inc_closure: Inc, dec_result: Dec) RocList {
if (list.bytes) |source_ptr| {
const size = list.len();
var i: usize = 0;
@ -200,23 +211,31 @@ pub fn listKeepResult(list: RocList, is_good_constructor: fn (RocResult) bool, t
var kept: usize = 0;
while (i < size) : (i += 1) {
const element = source_ptr + (i * before_width);
caller(transform, element, temporary);
const before_element = source_ptr + (i * before_width);
inc_closure(transform);
caller(transform, before_element, temporary);
const result = utils.RocResult{ .bytes = temporary };
const after_element = temporary + @sizeOf(i64);
if (is_good_constructor(result)) {
@memcpy(target_ptr + (kept * after_width), temporary + @sizeOf(i64), after_width);
@memcpy(target_ptr + (kept * after_width), after_element, after_width);
kept += 1;
} else {
dec_result(temporary);
}
}
output.length = kept;
utils.decref(std.heap.c_allocator, alignment, list.bytes, size * before_width);
std.heap.c_allocator.free(temporary[0..result_width]);
return output;
if (kept == 0) {
utils.decref(std.heap.c_allocator, alignment, output.bytes, size * after_width);
return RocList.empty();
} else {
output.length = kept;
return output;
}
} else {
return RocList.empty();
}
@ -278,3 +297,30 @@ pub fn listContains(list: RocList, key: Opaque, key_width: usize, is_eq: EqFn) c
return false;
}
pub fn listRepeat(count: usize, alignment: usize, element: Opaque, element_width: usize, inc_n_element: Inc) callconv(.C) RocList {
if (count == 0) {
return RocList.empty();
}
const allocator = std.heap.c_allocator;
var output = RocList.allocate(allocator, alignment, count, element_width);
if (output.bytes) |target_ptr| {
var i: usize = 0;
const source = element orelse unreachable;
while (i < count) : (i += 1) {
@memcpy(target_ptr + i * element_width, source, element_width);
}
// TODO do all increments at once!
i = 0;
while (i < count) : (i += 1) {
inc_n_element(element);
}
return output;
} else {
unreachable;
}
}

View file

@ -14,6 +14,7 @@ comptime {
exportListFn(list.listKeepOks, "keep_oks");
exportListFn(list.listKeepErrs, "keep_errs");
exportListFn(list.listContains, "contains");
exportListFn(list.listRepeat, "repeat");
}
// Dict Module

View file

@ -1,3 +1,4 @@
const utils = @import("utils.zig");
const std = @import("std");
const mem = std.mem;
const always_inline = std.builtin.CallOptions.Modifier.always_inline;
@ -47,18 +48,7 @@ pub const RocStr = extern struct {
}
pub fn initBig(allocator: *Allocator, in_place: InPlace, number_of_chars: u64) RocStr {
const length = @sizeOf(usize) + number_of_chars;
var new_bytes: []usize = allocator.alloc(usize, length) catch unreachable;
if (in_place == InPlace.InPlace) {
new_bytes[0] = @intCast(usize, number_of_chars);
} else {
const v: isize = std.math.minInt(isize);
new_bytes[0] = @bitCast(usize, v);
}
var first_element = @ptrCast([*]align(@alignOf(usize)) u8, new_bytes);
first_element += @sizeOf(usize);
const first_element = utils.allocateWithRefcount(allocator, @sizeOf(usize), number_of_chars);
return RocStr{
.str_bytes = first_element,
@ -833,8 +823,10 @@ pub fn strConcatC(result_in_place: InPlace, arg1: RocStr, arg2: RocStr) callconv
fn strConcat(allocator: *Allocator, result_in_place: InPlace, arg1: RocStr, arg2: RocStr) RocStr {
if (arg1.isEmpty()) {
// the second argument is borrowed, so we must increment its refcount before returning
return RocStr.clone(allocator, result_in_place, arg2);
} else if (arg2.isEmpty()) {
// the first argument is owned, so we can return it without cloning
return RocStr.clone(allocator, result_in_place, arg1);
} else {
const combined_length = arg1.len() + arg2.len();

View file

@ -16,25 +16,25 @@ pub fn decref(
var bytes = bytes_or_null orelse return;
const usizes: [*]usize = @ptrCast([*]usize, @alignCast(8, bytes));
const isizes: [*]isize = @ptrCast([*]isize, @alignCast(8, bytes));
const refcount = (usizes - 1)[0];
const refcount = (isizes - 1)[0];
const refcount_isize = @bitCast(isize, refcount);
switch (alignment) {
16 => {
if (refcount == REFCOUNT_ONE) {
if (refcount == REFCOUNT_ONE_ISIZE) {
allocator.free((bytes - 16)[0 .. 16 + data_bytes]);
} else if (refcount_isize < 0) {
(usizes - 1)[0] = refcount + 1;
(isizes - 1)[0] = refcount - 1;
}
},
else => {
// NOTE enums can currently have an alignment of < 8
if (refcount == REFCOUNT_ONE) {
if (refcount == REFCOUNT_ONE_ISIZE) {
allocator.free((bytes - 8)[0 .. 8 + data_bytes]);
} else if (refcount_isize < 0) {
(usizes - 1)[0] = refcount + 1;
(isizes - 1)[0] = refcount - 1;
}
},
}
@ -72,11 +72,11 @@ pub fn allocateWithRefcount(
var new_bytes: []align(8) u8 = allocator.alignedAlloc(u8, 8, length) catch unreachable;
var as_usize_array = @ptrCast([*]usize, new_bytes);
var as_usize_array = @ptrCast([*]isize, new_bytes);
if (result_in_place) {
as_usize_array[0] = @intCast(usize, number_of_slots);
as_usize_array[0] = @intCast(isize, number_of_slots);
} else {
as_usize_array[0] = REFCOUNT_ONE;
as_usize_array[0] = REFCOUNT_ONE_ISIZE;
}
var as_u8_array = @ptrCast([*]u8, new_bytes);

View file

@ -67,3 +67,4 @@ pub const LIST_KEEP_ERRS: &str = "roc_builtins.list.keep_errs";
pub const LIST_WALK: &str = "roc_builtins.list.walk";
pub const LIST_WALK_BACKWARDS: &str = "roc_builtins.list.walk_backwards";
pub const LIST_CONTAINS: &str = "roc_builtins.list.contains";
pub const LIST_REPEAT: &str = "roc_builtins.list.repeat";

View file

@ -2042,7 +2042,7 @@ fn dict_get(symbol: Symbol, var_store: &mut VarStore) -> Def {
let arg_dict = Symbol::ARG_1;
let arg_key = Symbol::ARG_2;
let temp_record = Symbol::ARG_3;
let temp_record = Symbol::DICT_GET_RESULT;
let bool_var = var_store.fresh();
let flag_var = var_store.fresh();

View file

@ -207,6 +207,15 @@ fn build_transform_caller_help<'a, 'ctx, 'env>(
function_value
}
pub fn build_inc_n_wrapper<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
layout: &Layout<'a>,
n: u64,
) -> FunctionValue<'ctx> {
build_rc_wrapper(env, layout_ids, layout, Mode::Inc(n))
}
pub fn build_inc_wrapper<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,

View file

@ -3595,9 +3595,7 @@ fn run_low_level<'a, 'ctx, 'env>(
let list_len = load_symbol(scope, &args[0]).into_int_value();
let (elem, elem_layout) = load_symbol_and_layout(scope, &args[1]);
let inplace = get_inplace_from_layout(layout);
list_repeat(env, inplace, parent, list_len, elem, elem_layout)
list_repeat(env, layout_ids, list_len, elem, elem_layout)
}
ListReverse => {
// List.reverse : List elem -> List elem

View file

@ -747,6 +747,9 @@ pub fn dict_walk<'a, 'ctx, 'env>(
let output_ptr = builder.build_alloca(accum_bt, "output_ptr");
let inc_key_fn = build_inc_wrapper(env, layout_ids, key_layout);
let inc_value_fn = build_inc_wrapper(env, layout_ids, value_layout);
call_void_bitcode_fn(
env,
&[
@ -758,6 +761,8 @@ pub fn dict_walk<'a, 'ctx, 'env>(
key_width.into(),
value_width.into(),
accum_width.into(),
inc_key_fn.as_global_value().as_pointer_value().into(),
inc_value_fn.as_global_value().as_pointer_value().into(),
env.builder.build_bitcast(output_ptr, u8_ptr, "to_opaque"),
],
&bitcode::DICT_WALK,

View file

@ -1,6 +1,7 @@
#![allow(clippy::too_many_arguments)]
use crate::llvm::bitcode::{
build_eq_wrapper, build_transform_caller, call_bitcode_fn, call_void_bitcode_fn,
build_dec_wrapper, build_eq_wrapper, build_inc_wrapper, build_transform_caller,
call_bitcode_fn, call_void_bitcode_fn,
};
use crate::llvm::build::{
allocate_with_refcount_help, build_num_binop, cast_basic_basic, complex_bitcast, Env, InPlace,
@ -53,90 +54,43 @@ pub fn list_single<'a, 'ctx, 'env>(
/// List.repeat : Int, elem -> List elem
pub fn list_repeat<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
inplace: InPlace,
parent: FunctionValue<'ctx>,
layout_ids: &mut LayoutIds<'a>,
list_len: IntValue<'ctx>,
elem: BasicValueEnum<'ctx>,
elem_layout: &Layout<'a>,
element: BasicValueEnum<'ctx>,
element_layout: &Layout<'a>,
) -> BasicValueEnum<'ctx> {
let builder = env.builder;
let ctx = env.context;
// list_len > 0
// We have to do a loop below, continuously adding the `elem`
// to the output list `List elem` until we have reached the
// number of repeats. This `comparison` is used to check
// if we need to do any looping; because if we dont, then we
// dont need to allocate memory for the index or the check
// if index != 0
let comparison = builder.build_int_compare(
IntPredicate::UGT,
list_len,
ctx.i64_type().const_int(0, false),
"atleastzero",
let u8_ptr = env.context.i8_type().ptr_type(AddressSpace::Generic);
let element_ptr = builder.build_alloca(element.get_type(), "element_ptr");
env.builder.build_store(element_ptr, element);
let element_width = env
.ptr_int()
.const_int(element_layout.stack_size(env.ptr_bytes) as u64, false);
let alignment = element_layout.alignment_bytes(env.ptr_bytes);
let alignment_iv = env.ptr_int().const_int(alignment as u64, false);
let inc_element_fn = build_inc_wrapper(env, layout_ids, element_layout);
let output = call_bitcode_fn(
env,
&[
list_len.into(),
alignment_iv.into(),
env.builder.build_bitcast(element_ptr, u8_ptr, "to_u8_ptr"),
element_width.into(),
inc_element_fn.as_global_value().as_pointer_value().into(),
],
bitcode::LIST_REPEAT,
);
let build_then = || {
// Allocate space for the new array that we'll copy into.
let list_ptr = allocate_list(env, inplace, elem_layout, list_len);
// TODO check if malloc returned null; if so, runtime error for OOM!
let index_name = "#index";
let start_alloca = builder.build_alloca(ctx.i64_type(), index_name);
// Start at the last element in the list.
let last_elem_index = builder.build_int_sub(
list_len,
ctx.i64_type().const_int(1, false),
"lastelemindex",
);
builder.build_store(start_alloca, last_elem_index);
let loop_bb = ctx.append_basic_block(parent, "loop");
builder.build_unconditional_branch(loop_bb);
builder.position_at_end(loop_bb);
// #index = #index - 1
let curr_index = builder
.build_load(start_alloca, index_name)
.into_int_value();
let next_index =
builder.build_int_sub(curr_index, ctx.i64_type().const_int(1, false), "nextindex");
builder.build_store(start_alloca, next_index);
let elem_ptr =
unsafe { builder.build_in_bounds_gep(list_ptr, &[curr_index], "load_index") };
// Mutate the new array in-place to change the element.
builder.build_store(elem_ptr, elem);
// #index != 0
let end_cond = builder.build_int_compare(
IntPredicate::NE,
ctx.i64_type().const_int(0, false),
curr_index,
"loopcond",
);
let after_bb = ctx.append_basic_block(parent, "afterloop");
builder.build_conditional_branch(end_cond, loop_bb, after_bb);
builder.position_at_end(after_bb);
store_list(env, list_ptr, list_len)
};
let build_else = || empty_polymorphic_list(env);
let struct_type = collection(ctx, env.ptr_bytes);
build_basic_phi2(
env,
parent,
comparison,
build_then,
build_else,
BasicTypeEnum::StructType(struct_type),
complex_bitcast(
env.builder,
output,
collection(env.context, env.ptr_bytes).into(),
"from_i128",
)
}
@ -1028,6 +982,9 @@ pub fn list_keep_if<'a, 'ctx, 'env>(
let alignment = element_layout.alignment_bytes(env.ptr_bytes);
let alignment_iv = env.ptr_int().const_int(alignment as u64, false);
let inc_element_fn = build_inc_wrapper(env, layout_ids, element_layout);
let dec_element_fn = build_dec_wrapper(env, layout_ids, element_layout);
let output = call_bitcode_fn(
env,
&[
@ -1037,6 +994,8 @@ pub fn list_keep_if<'a, 'ctx, 'env>(
stepper_caller.into(),
alignment_iv.into(),
element_width.into(),
inc_element_fn.as_global_value().as_pointer_value().into(),
dec_element_fn.as_global_value().as_pointer_value().into(),
],
&bitcode::LIST_KEEP_IF,
);
@ -1138,6 +1097,9 @@ pub fn list_keep_result<'a, 'ctx, 'env>(
let alignment = before_layout.alignment_bytes(env.ptr_bytes);
let alignment_iv = env.ptr_int().const_int(alignment as u64, false);
let inc_closure = build_inc_wrapper(env, layout_ids, transform_layout);
let dec_result_fn = build_dec_wrapper(env, layout_ids, result_layout);
let output = call_bitcode_fn(
env,
&[
@ -1149,6 +1111,8 @@ pub fn list_keep_result<'a, 'ctx, 'env>(
before_width.into(),
result_width.into(),
after_width.into(),
inc_closure.as_global_value().as_pointer_value().into(),
dec_result_fn.as_global_value().as_pointer_value().into(),
],
op,
);
@ -1532,7 +1496,7 @@ where
"bounds_check",
);
let after_loop_bb = ctx.append_basic_block(parent, "after_outer_loop");
let after_loop_bb = ctx.append_basic_block(parent, "after_outer_loop_1");
builder.build_conditional_branch(condition, loop_bb, after_loop_bb);
builder.position_at_end(after_loop_bb);
@ -1599,7 +1563,7 @@ where
// #index < end
let loop_end_cond = bounds_check_comparison(builder, next_index, end);
let after_loop_bb = ctx.append_basic_block(parent, "after_outer_loop");
let after_loop_bb = ctx.append_basic_block(parent, "after_outer_loop_2");
builder.build_conditional_branch(loop_end_cond, loop_bb, after_loop_bb);
builder.position_at_end(after_loop_bb);

View file

@ -285,6 +285,7 @@ fn modify_refcount_struct<'a, 'ctx, 'env>(
value: BasicValueEnum<'ctx>,
layouts: &[Layout<'a>],
mode: Mode,
when_recursive: &WhenRecursive<'a>,
) {
let wrapper_struct = value.into_struct_value();
@ -295,7 +296,15 @@ fn modify_refcount_struct<'a, 'ctx, 'env>(
.build_extract_value(wrapper_struct, i as u32, "decrement_struct_field")
.unwrap();
modify_refcount_layout(env, parent, layout_ids, mode, field_ptr, field_layout);
modify_refcount_layout_help(
env,
parent,
layout_ids,
mode,
when_recursive,
field_ptr,
field_layout,
);
}
}
}
@ -330,9 +339,9 @@ pub fn decrement_refcount_layout<'a, 'ctx, 'env>(
fn modify_refcount_builtin<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
parent: FunctionValue<'ctx>,
layout_ids: &mut LayoutIds<'a>,
mode: Mode,
when_recursive: &WhenRecursive<'a>,
value: BasicValueEnum<'ctx>,
layout: &Layout<'a>,
builtin: &Builtin<'a>,
@ -342,30 +351,17 @@ fn modify_refcount_builtin<'a, 'ctx, 'env>(
match builtin {
List(memory_mode, element_layout) => {
let wrapper_struct = value.into_struct_value();
if element_layout.contains_refcounted() {
let ptr_type =
basic_type_from_layout(env.arena, env.context, element_layout, env.ptr_bytes)
.ptr_type(AddressSpace::Generic);
let (len, ptr) = load_list(env.builder, wrapper_struct, ptr_type);
let loop_fn = |_index, element| {
modify_refcount_layout(env, parent, layout_ids, mode, element, element_layout);
};
incrementing_elem_loop(
env.builder,
env.context,
parent,
ptr,
len,
"modify_rc_index",
loop_fn,
);
}
if let MemoryMode::Refcounted = memory_mode {
modify_refcount_list(env, layout_ids, mode, layout, wrapper_struct);
modify_refcount_list(
env,
layout_ids,
mode,
when_recursive,
layout,
element_layout,
wrapper_struct,
);
}
}
Set(element_layout) => {
@ -380,6 +376,7 @@ fn modify_refcount_builtin<'a, 'ctx, 'env>(
env,
layout_ids,
mode,
when_recursive,
layout,
key_layout,
value_layout,
@ -404,13 +401,45 @@ fn modify_refcount_layout<'a, 'ctx, 'env>(
mode: Mode,
value: BasicValueEnum<'ctx>,
layout: &Layout<'a>,
) {
modify_refcount_layout_help(
env,
parent,
layout_ids,
mode,
&WhenRecursive::Unreachable,
value,
layout,
);
}
#[derive(Clone, Debug, PartialEq)]
enum WhenRecursive<'a> {
Unreachable,
Loop(UnionLayout<'a>),
}
fn modify_refcount_layout_help<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
parent: FunctionValue<'ctx>,
layout_ids: &mut LayoutIds<'a>,
mode: Mode,
when_recursive: &WhenRecursive<'a>,
value: BasicValueEnum<'ctx>,
layout: &Layout<'a>,
) {
use Layout::*;
match layout {
Builtin(builtin) => {
modify_refcount_builtin(env, parent, layout_ids, mode, value, layout, builtin)
}
Builtin(builtin) => modify_refcount_builtin(
env,
layout_ids,
mode,
when_recursive,
value,
layout,
builtin,
),
Union(variant) => {
use UnionLayout::*;
@ -425,6 +454,7 @@ fn modify_refcount_layout<'a, 'ctx, 'env>(
env,
layout_ids,
mode,
&WhenRecursive::Loop(variant.clone()),
tags,
value.into_pointer_value(),
true,
@ -440,6 +470,7 @@ fn modify_refcount_layout<'a, 'ctx, 'env>(
env,
layout_ids,
mode,
&WhenRecursive::Loop(variant.clone()),
&*env.arena.alloc([other_fields]),
value.into_pointer_value(),
true,
@ -453,6 +484,7 @@ fn modify_refcount_layout<'a, 'ctx, 'env>(
env,
layout_ids,
mode,
&WhenRecursive::Loop(variant.clone()),
&*env.arena.alloc([*fields]),
value.into_pointer_value(),
true,
@ -465,13 +497,16 @@ fn modify_refcount_layout<'a, 'ctx, 'env>(
env,
layout_ids,
mode,
&WhenRecursive::Loop(variant.clone()),
tags,
value.into_pointer_value(),
false,
);
}
NonRecursive(tags) => modify_refcount_union(env, layout_ids, mode, tags, value),
NonRecursive(tags) => {
modify_refcount_union(env, layout_ids, mode, when_recursive, tags, value)
}
}
}
Closure(_, closure_layout, _) => {
@ -483,11 +518,12 @@ fn modify_refcount_layout<'a, 'ctx, 'env>(
.build_extract_value(wrapper_struct, 1, "modify_rc_closure_data")
.unwrap();
modify_refcount_layout(
modify_refcount_layout_help(
env,
parent,
layout_ids,
mode,
when_recursive,
field_ptr,
&closure_layout.as_block_of_memory_layout(),
)
@ -495,12 +531,45 @@ fn modify_refcount_layout<'a, 'ctx, 'env>(
}
Struct(layouts) => {
modify_refcount_struct(env, parent, layout_ids, value, layouts, mode);
modify_refcount_struct(
env,
parent,
layout_ids,
value,
layouts,
mode,
when_recursive,
);
}
PhantomEmptyStruct => {}
RecursivePointer => todo!("TODO implement decrement layout of recursive tag union"),
Layout::RecursivePointer => match when_recursive {
WhenRecursive::Unreachable => {
unreachable!("recursion pointers should never be hashed directly")
}
WhenRecursive::Loop(union_layout) => {
let layout = Layout::Union(union_layout.clone());
let bt = basic_type_from_layout(env.arena, env.context, &layout, env.ptr_bytes);
// cast the i64 pointer to a pointer to block of memory
let field_cast = env
.builder
.build_bitcast(value, bt, "i64_to_opaque")
.into_pointer_value();
modify_refcount_layout_help(
env,
parent,
layout_ids,
mode,
when_recursive,
field_cast.into(),
&layout,
)
}
},
FunctionPointer(_, _) | Pointer(_) => {}
}
@ -510,7 +579,9 @@ fn modify_refcount_list<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
mode: Mode,
when_recursive: &WhenRecursive<'a>,
layout: &Layout<'a>,
element_layout: &Layout<'a>,
original_wrapper: StructValue<'ctx>,
) {
let block = env.builder.get_insert_block().expect("to be in a function");
@ -531,7 +602,15 @@ fn modify_refcount_list<'a, 'ctx, 'env>(
let basic_type = basic_type_from_layout(env.arena, env.context, &layout, env.ptr_bytes);
let function_value = build_header(env, basic_type, mode, &fn_name);
modify_refcount_list_help(env, mode, layout, function_value);
modify_refcount_list_help(
env,
layout_ids,
mode,
when_recursive,
layout,
element_layout,
function_value,
);
function_value
}
@ -553,8 +632,11 @@ fn mode_to_call_mode(function: FunctionValue<'_>, mode: Mode) -> CallMode<'_> {
fn modify_refcount_list_help<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
mode: Mode,
when_recursive: &WhenRecursive<'a>,
layout: &Layout<'a>,
element_layout: &Layout<'a>,
fn_val: FunctionValue<'ctx>,
) {
let builder = env.builder;
@ -593,6 +675,36 @@ fn modify_refcount_list_help<'a, 'ctx, 'env>(
builder.position_at_end(modification_block);
if element_layout.contains_refcounted() {
let ptr_type =
basic_type_from_layout(env.arena, env.context, element_layout, env.ptr_bytes)
.ptr_type(AddressSpace::Generic);
let (len, ptr) = load_list(env.builder, original_wrapper, ptr_type);
let loop_fn = |_index, element| {
modify_refcount_layout_help(
env,
parent,
layout_ids,
mode,
when_recursive,
element,
element_layout,
);
};
incrementing_elem_loop(
env.builder,
env.context,
parent,
ptr,
len,
"modify_rc_index",
loop_fn,
);
}
let refcount_ptr = PointerToRefcount::from_list_wrapper(env, original_wrapper);
let call_mode = mode_to_call_mode(fn_val, mode);
refcount_ptr.modify(call_mode, layout, env);
@ -701,10 +813,12 @@ fn modify_refcount_str_help<'a, 'ctx, 'env>(
builder.build_return(None);
}
#[allow(clippy::too_many_arguments)]
fn modify_refcount_dict<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
mode: Mode,
when_recursive: &WhenRecursive<'a>,
layout: &Layout<'a>,
key_layout: &Layout<'a>,
value_layout: &Layout<'a>,
@ -732,6 +846,7 @@ fn modify_refcount_dict<'a, 'ctx, 'env>(
env,
layout_ids,
mode,
when_recursive,
layout,
key_layout,
value_layout,
@ -749,15 +864,23 @@ fn modify_refcount_dict<'a, 'ctx, 'env>(
call_help(env, function, mode, original_wrapper.into(), call_name);
}
#[allow(clippy::too_many_arguments)]
fn modify_refcount_dict_help<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
mode: Mode,
when_recursive: &WhenRecursive<'a>,
layout: &Layout<'a>,
key_layout: &Layout<'a>,
value_layout: &Layout<'a>,
fn_val: FunctionValue<'ctx>,
) {
debug_assert_eq!(
when_recursive,
&WhenRecursive::Unreachable,
"TODO pipe when_recursive through the dict key/value inc/dec"
);
let builder = env.builder;
let ctx = env.context;
@ -784,7 +907,7 @@ fn modify_refcount_dict_help<'a, 'ctx, 'env>(
.into_int_value();
// the block we'll always jump to when we're done
let cont_block = ctx.append_basic_block(parent, "modify_rc_str_cont");
let cont_block = ctx.append_basic_block(parent, "modify_rc_dict_cont");
let modification_block = ctx.append_basic_block(parent, "modify_rc");
let is_non_empty = builder.build_int_compare(
@ -894,6 +1017,7 @@ fn build_rec_union<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
mode: Mode,
when_recursive: &WhenRecursive<'a>,
fields: &'a [&'a [Layout<'a>]],
value: PointerValue<'ctx>,
is_nullable: bool,
@ -920,7 +1044,15 @@ fn build_rec_union<'a, 'ctx, 'env>(
.into();
let function_value = build_header(env, basic_type, mode, &fn_name);
build_rec_union_help(env, layout_ids, mode, fields, function_value, is_nullable);
build_rec_union_help(
env,
layout_ids,
mode,
when_recursive,
fields,
function_value,
is_nullable,
);
env.builder.position_at_end(block);
env.builder
@ -937,6 +1069,7 @@ fn build_rec_union_help<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
mode: Mode,
when_recursive: &WhenRecursive<'a>,
tags: &[&[Layout<'a>]],
fn_val: FunctionValue<'ctx>,
is_nullable: bool,
@ -1093,7 +1226,15 @@ fn build_rec_union_help<'a, 'ctx, 'env>(
refcount_ptr.modify(call_mode, &layout, env);
for (field, field_layout) in deferred_nonrec {
modify_refcount_layout(env, parent, layout_ids, mode, field, field_layout);
modify_refcount_layout_help(
env,
parent,
layout_ids,
mode,
when_recursive,
field,
field_layout,
);
}
let call_name = pick("recursive_tag_increment", "recursive_tag_decrement");
@ -1208,6 +1349,7 @@ fn modify_refcount_union<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
mode: Mode,
when_recursive: &WhenRecursive<'a>,
fields: &'a [&'a [Layout<'a>]],
value: BasicValueEnum<'ctx>,
) {
@ -1231,7 +1373,14 @@ fn modify_refcount_union<'a, 'ctx, 'env>(
let basic_type = block_of_memory(env.context, &layout, env.ptr_bytes);
let function_value = build_header(env, basic_type, mode, &fn_name);
modify_refcount_union_help(env, layout_ids, mode, fields, function_value);
modify_refcount_union_help(
env,
layout_ids,
mode,
when_recursive,
fields,
function_value,
);
function_value
}
@ -1248,6 +1397,7 @@ fn modify_refcount_union_help<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
mode: Mode,
when_recursive: &WhenRecursive<'a>,
tags: &[&[Layout<'a>]],
fn_val: FunctionValue<'ctx>,
) {
@ -1332,7 +1482,15 @@ fn modify_refcount_union_help<'a, 'ctx, 'env>(
.build_extract_value(wrapper_struct, i as u32, "modify_tag_field")
.unwrap();
modify_refcount_layout(env, parent, layout_ids, mode, field_ptr, field_layout);
modify_refcount_layout_help(
env,
parent,
layout_ids,
mode,
when_recursive,
field_ptr,
field_layout,
);
}
}

View file

@ -912,23 +912,24 @@ define_builtins! {
2 DICT_EMPTY: "empty"
3 DICT_SINGLETON: "singleton"
4 DICT_GET: "get"
5 DICT_INSERT: "insert"
6 DICT_LEN: "len"
5 DICT_GET_RESULT: "#get_result" // symbol used in the definition of Dict.get
6 DICT_WALK: "walk"
7 DICT_INSERT: "insert"
8 DICT_LEN: "len"
// This should not be exposed to users, its for testing the
// hash function ONLY
7 DICT_TEST_HASH: "hashTestOnly"
9 DICT_TEST_HASH: "hashTestOnly"
8 DICT_REMOVE: "remove"
9 DICT_CONTAINS: "contains"
10 DICT_KEYS: "keys"
11 DICT_VALUES: "values"
10 DICT_REMOVE: "remove"
11 DICT_CONTAINS: "contains"
12 DICT_KEYS: "keys"
13 DICT_VALUES: "values"
12 DICT_UNION: "union"
13 DICT_INTERSECTION: "intersection"
14 DICT_DIFFERENCE: "difference"
14 DICT_UNION: "union"
15 DICT_INTERSECTION: "intersection"
16 DICT_DIFFERENCE: "difference"
15 DICT_WALK: "walk"
}
7 SET: "Set" => {

View file

@ -610,7 +610,9 @@ impl<'a> BorrowInfState<'a> {
}
pub fn foreign_borrow_signature(arena: &Bump, arity: usize) -> &[bool] {
let all = bumpalo::vec![in arena; false; arity];
// NOTE this means that Roc is responsible for cleaning up resources;
// the host cannot (currently) take ownership
let all = bumpalo::vec![in arena; true; arity];
all.into_bump_slice()
}
@ -632,16 +634,16 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] {
ListSet => arena.alloc_slice_copy(&[owned, irrelevant, irrelevant]),
ListSetInPlace => arena.alloc_slice_copy(&[owned, irrelevant, irrelevant]),
ListGetUnsafe => arena.alloc_slice_copy(&[borrowed, irrelevant]),
ListConcat | StrConcat => arena.alloc_slice_copy(&[owned, borrowed]),
ListConcat | StrConcat => arena.alloc_slice_copy(&[borrowed, borrowed]),
StrSplit => arena.alloc_slice_copy(&[borrowed, borrowed]),
ListSingle => arena.alloc_slice_copy(&[irrelevant]),
ListRepeat => arena.alloc_slice_copy(&[irrelevant, irrelevant]),
ListRepeat => arena.alloc_slice_copy(&[irrelevant, borrowed]),
ListReverse => arena.alloc_slice_copy(&[owned]),
ListPrepend => arena.alloc_slice_copy(&[owned, owned]),
StrJoinWith => arena.alloc_slice_copy(&[irrelevant, irrelevant]),
StrJoinWith => arena.alloc_slice_copy(&[borrowed, borrowed]),
ListJoin => arena.alloc_slice_copy(&[irrelevant]),
ListMap | ListMapWithIndex => arena.alloc_slice_copy(&[owned, irrelevant]),
ListKeepIf | ListKeepOks | ListKeepErrs => arena.alloc_slice_copy(&[owned, irrelevant]),
ListKeepIf | ListKeepOks | ListKeepErrs => arena.alloc_slice_copy(&[owned, borrowed]),
ListContains => arena.alloc_slice_copy(&[borrowed, irrelevant]),
ListWalk => arena.alloc_slice_copy(&[owned, irrelevant, owned]),
ListWalkBackwards => arena.alloc_slice_copy(&[owned, irrelevant, owned]),
@ -651,9 +653,11 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] {
// List.append should own its first argument
ListAppend => arena.alloc_slice_copy(&[borrowed, owned]),
Eq | NotEq | And | Or | NumAdd | NumAddWrap | NumAddChecked | NumSub | NumSubWrap
| NumSubChecked | NumMul | NumMulWrap | NumMulChecked | NumGt | NumGte | NumLt | NumLte
| NumCompare | NumDivUnchecked | NumRemUnchecked | NumPow | NumPowInt | NumBitwiseAnd
Eq | NotEq => arena.alloc_slice_copy(&[borrowed, borrowed]),
And | Or | NumAdd | NumAddWrap | NumAddChecked | NumSub | NumSubWrap | NumSubChecked
| NumMul | NumMulWrap | NumMulChecked | NumGt | NumGte | NumLt | NumLte | NumCompare
| NumDivUnchecked | NumRemUnchecked | NumPow | NumPowInt | NumBitwiseAnd
| NumBitwiseXor => arena.alloc_slice_copy(&[irrelevant, irrelevant]),
NumAbs | NumNeg | NumSin | NumCos | NumSqrtUnchecked | NumRound | NumCeiling | NumFloor

View file

@ -729,8 +729,12 @@ impl<'a> Context<'a> {
layout,
} => {
// TODO this combines parts of Let and Switch. Did this happen correctly?
let mut case_live_vars = collect_stmt(stmt, &self.jp_live_vars, MutSet::default());
let mut case_live_vars = collect_stmt(pass, &self.jp_live_vars, MutSet::default());
case_live_vars.extend(collect_stmt(fail, &self.jp_live_vars, MutSet::default()));
// the result of an invoke should not be touched in the fail branch
// but it should be present in the pass branch (otherwise it would be dead)
debug_assert!(case_live_vars.contains(symbol));
case_live_vars.remove(symbol);
let fail = {
@ -758,9 +762,50 @@ impl<'a> Context<'a> {
layout: layout.clone(),
};
let stmt = self.arena.alloc(invoke);
let cont = self.arena.alloc(invoke);
(stmt, case_live_vars)
use crate::ir::CallType;
let stmt = match &call.call_type {
CallType::LowLevel { op } => {
let ps = crate::borrow::lowlevel_borrow_signature(self.arena, *op);
self.add_dec_after_lowlevel(call.arguments, ps, cont, &case_live_vars)
}
CallType::Foreign { .. } => {
let ps = crate::borrow::foreign_borrow_signature(
self.arena,
call.arguments.len(),
);
self.add_dec_after_lowlevel(call.arguments, ps, cont, &case_live_vars)
}
CallType::ByName {
name, full_layout, ..
} => {
// get the borrow signature
match self.param_map.get_symbol(*name, full_layout.clone()) {
Some(ps) => self.add_dec_after_application(
call.arguments,
ps,
cont,
&case_live_vars,
),
None => self.add_inc_before_consume_all(
call.arguments,
cont,
&case_live_vars,
),
}
}
CallType::ByPointer { .. } => {
self.add_inc_before_consume_all(call.arguments, cont, &case_live_vars)
}
};
let mut invoke_live_vars = case_live_vars;
occuring_variables_call(call, &mut invoke_live_vars);
(stmt, invoke_live_vars)
}
Join {
id: j,

View file

@ -5183,10 +5183,7 @@ fn store_pattern_help<'a>(
return StorePattern::NotProductive(stmt);
}
AppliedTag {
arguments,
layout,
tag_name,
..
arguments, layout, ..
} => {
let wrapped = Wrapped::from_layout(layout);
let write_tag = wrapped == Wrapped::MultiTagUnion;
@ -5241,12 +5238,6 @@ fn store_pattern_help<'a>(
match store_pattern_help(env, procs, layout_cache, argument, symbol, stmt) {
StorePattern::Productive(new) => {
is_productive = true;
println!(
"Access {:?}.{:?} {:?}",
tag_name.clone(),
outer_symbol,
index
);
stmt = new;
// only if we bind one of its (sub)fields to a used name should we
// extract the field
@ -5948,7 +5939,8 @@ fn call_by_name<'a>(
debug_assert_eq!(
arg_layouts.len(),
field_symbols.len(),
"see call_by_name for background (scroll down a bit)"
"see call_by_name for background (scroll down a bit), function is {:?}",
proc_name,
);
let call = self::Call {
@ -5999,7 +5991,8 @@ fn call_by_name<'a>(
debug_assert_eq!(
arg_layouts.len(),
field_symbols.len(),
"see call_by_name for background (scroll down a bit)"
"see call_by_name for background (scroll down a bit), function is {:?}",
proc_name,
);
let call = self::Call {
@ -6521,8 +6514,10 @@ fn from_can_pattern_help<'a>(
debug_assert_eq!(
arguments.len(),
argument_layouts[1..].len(),
"{:?}",
tag_name
"The {:?} tag got {} arguments, but its layout expects {}!",
tag_name,
arguments.len(),
argument_layouts[1..].len(),
);
let it = argument_layouts[1..].iter();

View file

@ -877,7 +877,7 @@ impl<'a> Builtin<'a> {
/// Number of machine words in an empty one of these
pub const STR_WORDS: u32 = 2;
pub const DICT_WORDS: u32 = 6;
pub const DICT_WORDS: u32 = 3;
pub const SET_WORDS: u32 = Builtin::DICT_WORDS; // Set is an alias for Dict with {} for value
pub const LIST_WORDS: u32 = 2;

View file

@ -646,13 +646,13 @@ mod test_mono {
let Test.4 = lowlevel DictEmpty ;
ret Test.4;
procedure Dict.6 (#Attr.2):
procedure Dict.8 (#Attr.2):
let Test.3 = lowlevel DictSize #Attr.2;
ret Test.3;
procedure Test.0 ():
let Test.2 = FunctionPointer Dict.2;
let Test.1 = CallByName Dict.6 Test.2;
let Test.1 = CallByName Dict.8 Test.2;
ret Test.1;
"#
),

View file

@ -1,6 +1,6 @@
interface AStar exposes [ findPath, Model, initialModel ] imports [Quicksort]
interface AStar exposes [ findPath, Model, initialModel, cheapestOpen, takeStep, reconstructPath ] imports [Quicksort]
findPath = \costFn, moveFn, start, end ->
findPath = \costFn, moveFn, start, end ->
astar costFn moveFn end (initialModel start)
Model position :
@ -14,9 +14,9 @@ Model position :
initialModel : position -> Model position
initialModel = \start ->
{
evaluated : Set.empty,
evaluated : Set.empty,
openSet : Set.singleton start,
costs : Dict.singleton start 0,
costs : Dict.singleton start 0,
cameFrom : Dict.empty
}
@ -50,44 +50,28 @@ reconstructPath = \cameFrom, goal ->
updateCost : position, position, Model position -> Model position
updateCost = \current, neighbor, model ->
newCameFrom =
Dict.insert model.cameFrom neighbor current
newCosts =
Dict.insert model.costs neighbor distanceTo
distanceTo =
reconstructPath newCameFrom neighbor
|> List.len
|> Num.toFloat
newModel =
{ model &
costs: newCosts,
cameFrom: newCameFrom
}
when Dict.get model.costs neighbor is
Err _ ->
newCameFrom =
Dict.insert model.cameFrom neighbor current
newCosts =
Dict.insert model.costs neighbor distanceTo
distanceTo =
reconstructPath newCameFrom neighbor
|> List.len
|> Num.toFloat
{ model &
costs: newCosts,
cameFrom: newCameFrom
}
newModel
Ok previousDistance ->
newCameFrom =
Dict.insert model.cameFrom neighbor current
newCosts =
Dict.insert model.costs neighbor distanceTo
distanceTo =
reconstructPath newCameFrom neighbor
|> List.len
|> Num.toFloat
newModel =
{ model &
costs: newCosts,
cameFrom: newCameFrom
}
if distanceTo < previousDistance then
newModel
@ -126,3 +110,27 @@ astar = \costFn, moveFn, goal, model ->
Set.walk newNeighbors (\n, m -> updateCost current n m) modelWithNeighbors
astar costFn moveFn goal modelWithCosts
takeStep = \moveFn, _goal, model, current ->
modelPopped =
{ model &
openSet: Set.remove model.openSet current,
evaluated: Set.insert model.evaluated current,
}
neighbors = moveFn current
newNeighbors = Set.difference neighbors modelPopped.evaluated
modelWithNeighbors = { modelPopped & openSet: Set.union modelPopped.openSet newNeighbors }
# a lot goes wrong here
modelWithCosts =
Set.walk newNeighbors (\n, m -> updateCost current n m) modelWithNeighbors
modelWithCosts

View file

@ -3,20 +3,18 @@ app "astar-tests"
imports [base.Task, AStar]
provides [ main ] to base
fromList : List a -> Set a
fromList = \list -> List.walk list (\x, a -> Set.insert a x) Set.empty
main : Task.Task {} []
main =
Task.after Task.getInt \n ->
when n is
1 ->
Task.putLine (showBool test1)
Task.putLine (showBool test1)
_ ->
ns = Str.fromInt n
Task.putLine "No test \(ns)"
# Task.after Task.getInt \n ->
# when n is
# 1 ->
# Task.putLine (showBool test1)
#
# _ ->
# ns = Str.fromInt n
# Task.putLine "No test \(ns)"
showBool : Bool -> Str
showBool = \b ->
@ -26,17 +24,17 @@ showBool = \b ->
test1 : Bool
test1 =
example1 == [3, 4]
example1 == [2, 4]
example1 : List I64
example1 =
step : I64 -> Set I64
step = \n ->
when n is
1 -> fromList [ 2,3 ]
2 -> fromList [4]
3 -> fromList [4]
_ -> fromList []
1 -> Set.fromList [ 2,3 ]
2 -> Set.fromList [4]
3 -> Set.fromList [4]
_ -> Set.fromList []
cost : I64, I64 -> F64
cost = \_, _ -> 1

View file

@ -0,0 +1,58 @@
app "closure"
packages { base: "platform" }
imports [base.Task]
provides [ main ] to base
# see https://github.com/rtfeldman/roc/issues/985
main : Task.Task {} []
main = closure1 {}
|> Task.after (\_ -> closure2 {})
|> Task.after (\_ -> closure2 {})
|> Task.after (\_ -> closure2 {})
# ---
closure1 : {} -> Task.Task {} []
closure1 = \_ ->
Task.succeed (foo toUnitBorrowed "a long string such that it's malloced")
|> Task.map (\_ -> {})
toUnitBorrowed = \x -> Str.countGraphemes x
foo = \f, x -> f x
# ---
closure2 : {} -> Task.Task {} []
closure2 = \_ ->
x : Str
x = "a long string such that it's malloced"
Task.succeed {}
|> Task.map (\_ -> x)
|> Task.map toUnit
toUnit = (\_ -> {})
# ---
closure3 : {} -> Task.Task {} []
closure3 = \_ ->
x : Str
x = "a long string such that it's malloced"
Task.succeed {}
|> Task.after (\_ -> Task.succeed x |> Task.map (\_ -> {}))
# ---
closure4 : {} -> Task.Task {} []
closure4 = \_ ->
x : Str
x = "a long string such that it's malloced"
Task.succeed {}
|> Task.after (\_ -> Task.succeed x)
|> Task.map (\_ -> {})

View file

@ -1,15 +0,0 @@
app "closure1"
packages { base: "platform" }
imports [base.Task]
provides [ main ] to base
# see https://github.com/rtfeldman/roc/issues/985
main : Task.Task {} []
main =
Task.succeed (foo toUnitBorrowed "a long string such that it's malloced")
|> Task.map (\_ -> {})
toUnitBorrowed = \x -> Str.countGraphemes x
foo = \f, x -> f x

View file

@ -1,17 +0,0 @@
app "closure2"
packages { base: "platform" }
imports [base.Task]
provides [ main ] to base
# see https://github.com/rtfeldman/roc/issues/985
main : Task.Task {} []
main =
x : Str
x = "a long string such that it's malloced"
Task.succeed {}
|> Task.map (\_ -> x)
|> Task.map toUnit
toUnit = (\_ -> {})

View file

@ -1,14 +0,0 @@
app "closure3"
packages { base: "platform" }
imports [base.Task]
provides [ main ] to base
# see https://github.com/rtfeldman/roc/issues/985
main : Task.Task {} []
main =
x : Str
x = "a long string such that it's malloced"
Task.succeed {}
|> Task.after (\_ -> Task.succeed x |> Task.map (\_ -> {}))

View file

@ -1,16 +0,0 @@
app "closure4"
packages { base: "platform" }
imports [base.Task]
provides [ main ] to base
# see https://github.com/rtfeldman/roc/issues/985
main : Task.Task {} []
main =
x : Str
x = "a long string such that it's malloced"
Task.succeed {}
|> Task.after (\_ -> Task.succeed x)
|> Task.map (\_ -> {})

View file

@ -0,0 +1,77 @@
interface Quicksort exposes [ sortBy, show ] imports []
show : List I64 -> Str
show = \list ->
if List.isEmpty list then
"[]"
else
content =
list
|> List.map Str.fromInt
|> Str.joinWith ", "
"[ \(content) ]"
sortBy : List a, (a -> Num *) -> List a
sortBy = \list, toComparable ->
sortWith list (\x, y -> Num.compare (toComparable x) (toComparable y))
Order a : a, a -> [ LT, GT, EQ ]
sortWith : List a, (a, a -> [ LT, GT, EQ ]) -> List a
sortWith = \list, order ->
n = List.len list
quicksortHelp list order 0 (n - 1)
quicksortHelp : List a, Order a, Nat, Nat -> List a
quicksortHelp = \list, order, low, high ->
if low < high then
when partition low high list order is
Pair partitionIndex partitioned ->
partitioned
|> quicksortHelp order low (partitionIndex - 1)
|> quicksortHelp order (partitionIndex + 1) high
else
list
partition : Nat, Nat, List a, Order a -> [ Pair Nat (List a) ]
partition = \low, high, initialList, order ->
when List.get initialList high is
Ok pivot ->
when partitionHelp (low - 1) low initialList order high pivot is
Pair newI newList ->
Pair (newI + 1) (swap (newI + 1) high newList)
Err _ ->
Pair (low - 1) initialList
partitionHelp : Nat, Nat, List c, Order c, Nat, c -> [ Pair Nat (List c) ]
partitionHelp = \i, j, list, order, high, pivot ->
if j < high then
when List.get list j is
Ok value ->
when order value pivot is
LT | EQ ->
partitionHelp (i + 1) (j + 1) (swap (i + 1) j list) order high pivot
GT ->
partitionHelp i (j + 1) list order high pivot
Err _ ->
Pair i list
else
Pair i list
swap : Nat, Nat, List a -> List a
swap = \i, j, list ->
when Pair (List.get list i) (List.get list j) is
Pair (Ok atI) (Ok atJ) ->
list
|> List.set i atJ
|> List.set j atI
_ ->
[]