Merge remote-tracking branch 'origin/trunk' into type-start-parse-error

This commit is contained in:
Richard Feldman 2021-03-30 23:10:43 -04:00
commit a1de6925c8
42 changed files with 1581 additions and 480 deletions

View file

@ -4,6 +4,8 @@ const RocResult = utils.RocResult;
const mem = std.mem;
const Allocator = mem.Allocator;
const TAG_WIDTH = 8;
const EqFn = fn (?[*]u8, ?[*]u8) callconv(.C) bool;
const Opaque = ?[*]u8;
@ -502,6 +504,49 @@ pub fn listWalkBackwards(list: RocList, stepper: Opaque, stepper_caller: Caller2
utils.decref(std.heap.c_allocator, alignment, list.bytes, data_bytes);
}
pub fn listWalkUntil(list: RocList, stepper: Opaque, stepper_caller: Caller2, accum: Opaque, alignment: usize, element_width: usize, accum_width: usize, dec: Dec, output: Opaque) callconv(.C) void {
// [ Continue a, Stop a ]
const CONTINUE: usize = 0;
if (accum_width == 0) {
return;
}
if (list.isEmpty()) {
@memcpy(output orelse unreachable, accum orelse unreachable, accum_width);
return;
}
const alloc: [*]u8 = @ptrCast([*]u8, std.heap.c_allocator.alloc(u8, TAG_WIDTH + accum_width) catch unreachable);
@memcpy(alloc + TAG_WIDTH, accum orelse unreachable, accum_width);
if (list.bytes) |source_ptr| {
var i: usize = 0;
const size = list.len();
while (i < size) : (i += 1) {
const element = source_ptr + i * element_width;
stepper_caller(stepper, element, alloc + TAG_WIDTH, alloc);
const usizes: [*]usize = @ptrCast([*]usize, @alignCast(8, alloc));
if (usizes[0] != 0) {
// decrement refcount of the remaining items
i += 1;
while (i < size) : (i += 1) {
dec(source_ptr + i * element_width);
}
break;
}
}
}
@memcpy(output orelse unreachable, alloc + TAG_WIDTH, accum_width);
std.heap.c_allocator.free(alloc[0 .. TAG_WIDTH + accum_width]);
const data_bytes = list.len() * element_width;
utils.decref(std.heap.c_allocator, alignment, list.bytes, data_bytes);
}
// List.contains : List k, k -> Bool
pub fn listContains(list: RocList, key: Opaque, key_width: usize, is_eq: EqFn) callconv(.C) bool {
if (list.bytes) |source_ptr| {

View file

@ -12,6 +12,7 @@ comptime {
exportListFn(list.listMapWithIndex, "map_with_index");
exportListFn(list.listKeepIf, "keep_if");
exportListFn(list.listWalk, "walk");
exportListFn(list.listWalkUntil, "walkUntil");
exportListFn(list.listWalkBackwards, "walk_backwards");
exportListFn(list.listKeepOks, "keep_oks");
exportListFn(list.listKeepErrs, "keep_errs");

View file

@ -195,7 +195,7 @@ interface List2
## * Even when copying is faster, other list operations may still be slightly slower with persistent data structures. For example, even if it were a persistent data structure, #List.map, #List.fold, and #List.keepIf would all need to traverse every element in the list and build up the result from scratch. These operations are all
## * Roc's compiler optimizes many list operations into in-place mutations behind the scenes, depending on how the list is being used. For example, #List.map, #List.keepIf, and #List.set can all be optimized to perform in-place mutations.
## * If possible, it is usually best for performance to use large lists in a way where the optimizer can turn them into in-place mutations. If this is not possible, a persistent data structure might be faster - but this is a rare enough scenario that it would not be good for the average Roc program's performance if this were the way #List worked by default. Instead, you can look outside Roc's standard modules for an implementation of a persistent data structure - likely built using #List under the hood!
List elem : @List elem
List elem : [ @List elem ]
## Initialize

View file

@ -1,7 +1,11 @@
interface Set2
exposes [ empty, isEmpty, len, add, drop, map ]
interface Set
exposes [ Set, empty, isEmpty, len, add, drop, map ]
imports []
## Set
## A Set is an unordered collection of unique elements.
Set elem : [ @Set elem ]
## An empty set.
empty : Set *

View file

@ -70,6 +70,7 @@ pub const LIST_KEEP_IF: &str = "roc_builtins.list.keep_if";
pub const LIST_KEEP_OKS: &str = "roc_builtins.list.keep_oks";
pub const LIST_KEEP_ERRS: &str = "roc_builtins.list.keep_errs";
pub const LIST_WALK: &str = "roc_builtins.list.walk";
pub const LIST_WALK_UNTIL: &str = "roc_builtins.list.walkUntil";
pub const LIST_WALK_BACKWARDS: &str = "roc_builtins.list.walk_backwards";
pub const LIST_CONTAINS: &str = "roc_builtins.list.contains";
pub const LIST_REPEAT: &str = "roc_builtins.list.repeat";

View file

@ -771,6 +771,34 @@ pub fn types() -> MutMap<Symbol, (SolvedType, Region)> {
),
);
fn until_type(content: SolvedType) -> SolvedType {
// [ LT, EQ, GT ]
SolvedType::TagUnion(
vec![
(TagName::Global("Continue".into()), vec![content.clone()]),
(TagName::Global("Stop".into()), vec![content]),
],
Box::new(SolvedType::EmptyTagUnion),
)
}
// walkUntil : List elem, (elem -> accum -> [ Continue accum, Stop accum ]), accum -> accum
add_type(
Symbol::LIST_WALK_UNTIL,
top_level_function(
vec![
list_type(flex(TVAR1)),
closure(
vec![flex(TVAR1), flex(TVAR2)],
TVAR3,
Box::new(until_type(flex(TVAR2))),
),
flex(TVAR2),
],
Box::new(flex(TVAR2)),
),
);
// keepIf : List elem, (elem -> Bool) -> List elem
add_type(
Symbol::LIST_KEEP_IF,

View file

@ -89,6 +89,7 @@ pub fn builtin_defs_map(symbol: Symbol, var_store: &mut VarStore) -> Option<Def>
LIST_KEEP_ERRS=> list_keep_errs,
LIST_WALK => list_walk,
LIST_WALK_BACKWARDS => list_walk_backwards,
LIST_WALK_UNTIL => list_walk_until,
DICT_TEST_HASH => dict_hash_test_only,
DICT_LEN => dict_len,
DICT_EMPTY => dict_empty,
@ -231,6 +232,7 @@ pub fn builtin_defs(var_store: &mut VarStore) -> MutMap<Symbol, Def> {
Symbol::LIST_KEEP_ERRS=> list_keep_errs,
Symbol::LIST_WALK => list_walk,
Symbol::LIST_WALK_BACKWARDS => list_walk_backwards,
Symbol::LIST_WALK_UNTIL => list_walk_until,
Symbol::DICT_TEST_HASH => dict_hash_test_only,
Symbol::DICT_LEN => dict_len,
Symbol::DICT_EMPTY => dict_empty,
@ -2094,60 +2096,17 @@ fn list_join(symbol: Symbol, var_store: &mut VarStore) -> Def {
/// List.walk : List elem, (elem -> accum -> accum), accum -> accum
fn list_walk(symbol: Symbol, var_store: &mut VarStore) -> Def {
let list_var = var_store.fresh();
let func_var = var_store.fresh();
let accum_var = var_store.fresh();
let body = RunLowLevel {
op: LowLevel::ListWalk,
args: vec![
(list_var, Var(Symbol::ARG_1)),
(func_var, Var(Symbol::ARG_2)),
(accum_var, Var(Symbol::ARG_3)),
],
ret_var: accum_var,
};
defn(
symbol,
vec![
(list_var, Symbol::ARG_1),
(func_var, Symbol::ARG_2),
(accum_var, Symbol::ARG_3),
],
var_store,
body,
accum_var,
)
lowlevel_3(symbol, LowLevel::ListWalk, var_store)
}
/// List.walkBackwards : List elem, (elem -> accum -> accum), accum -> accum
fn list_walk_backwards(symbol: Symbol, var_store: &mut VarStore) -> Def {
let list_var = var_store.fresh();
let func_var = var_store.fresh();
let accum_var = var_store.fresh();
lowlevel_3(symbol, LowLevel::ListWalkBackwards, var_store)
}
let body = RunLowLevel {
op: LowLevel::ListWalkBackwards,
args: vec![
(list_var, Var(Symbol::ARG_1)),
(func_var, Var(Symbol::ARG_2)),
(accum_var, Var(Symbol::ARG_3)),
],
ret_var: accum_var,
};
defn(
symbol,
vec![
(list_var, Symbol::ARG_1),
(func_var, Symbol::ARG_2),
(accum_var, Symbol::ARG_3),
],
var_store,
body,
accum_var,
)
/// List.walkUntil : List elem, (elem, accum -> [ Continue accum, Stop accum ]), accum -> accum
fn list_walk_until(symbol: Symbol, var_store: &mut VarStore) -> Def {
lowlevel_3(symbol, LowLevel::ListWalkUntil, var_store)
}
/// List.sum : List (Num a) -> Num a

View file

@ -8,7 +8,7 @@ use crate::llvm::build_list::{
allocate_list, empty_list, empty_polymorphic_list, list_append, list_concat, list_contains,
list_get_unsafe, list_join, list_keep_errs, list_keep_if, list_keep_oks, list_len, list_map,
list_map2, list_map3, list_map_with_index, list_prepend, list_product, list_repeat,
list_reverse, list_set, list_single, list_sum, list_walk, list_walk_backwards,
list_reverse, list_set, list_single, list_sum, list_walk_help,
};
use crate::llvm::build_str::{
str_concat, str_count_graphemes, str_ends_with, str_from_float, str_from_int, str_from_utf8,
@ -3879,57 +3879,30 @@ fn run_low_level<'a, 'ctx, 'env>(
list_contains(env, layout_ids, elem, elem_layout, list)
}
ListWalk => {
debug_assert_eq!(args.len(), 3);
let (list, list_layout) = load_symbol_and_layout(scope, &args[0]);
let (func, func_layout) = load_symbol_and_layout(scope, &args[1]);
let (default, default_layout) = load_symbol_and_layout(scope, &args[2]);
match list_layout {
Layout::Builtin(Builtin::EmptyList) => default,
Layout::Builtin(Builtin::List(_, element_layout)) => list_walk(
env,
layout_ids,
parent,
list,
element_layout,
func,
func_layout,
default,
default_layout,
),
_ => unreachable!("invalid list layout"),
}
}
ListWalkBackwards => {
// List.walkBackwards : List elem, (elem -> accum -> accum), accum -> accum
debug_assert_eq!(args.len(), 3);
let (list, list_layout) = load_symbol_and_layout(scope, &args[0]);
let (func, func_layout) = load_symbol_and_layout(scope, &args[1]);
let (default, default_layout) = load_symbol_and_layout(scope, &args[2]);
match list_layout {
Layout::Builtin(Builtin::EmptyList) => default,
Layout::Builtin(Builtin::List(_, element_layout)) => list_walk_backwards(
env,
layout_ids,
parent,
list,
element_layout,
func,
func_layout,
default,
default_layout,
),
_ => unreachable!("invalid list layout"),
}
}
ListWalk => list_walk_help(
env,
layout_ids,
scope,
parent,
args,
crate::llvm::build_list::ListWalk::Walk,
),
ListWalkUntil => list_walk_help(
env,
layout_ids,
scope,
parent,
args,
crate::llvm::build_list::ListWalk::WalkUntil,
),
ListWalkBackwards => list_walk_help(
env,
layout_ids,
scope,
parent,
args,
crate::llvm::build_list::ListWalk::WalkBackwards,
),
ListSum => {
debug_assert_eq!(args.len(), 1);

View file

@ -863,56 +863,47 @@ pub fn list_product<'a, 'ctx, 'env>(
builder.build_load(accum_alloca, "load_final_acum")
}
/// List.walk : List elem, (elem -> accum -> accum), accum -> accum
pub fn list_walk<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
parent: FunctionValue<'ctx>,
list: BasicValueEnum<'ctx>,
element_layout: &Layout<'a>,
func: BasicValueEnum<'ctx>,
func_layout: &Layout<'a>,
default: BasicValueEnum<'ctx>,
default_layout: &Layout<'a>,
) -> BasicValueEnum<'ctx> {
list_walk_generic(
env,
layout_ids,
parent,
list,
element_layout,
func,
func_layout,
default,
default_layout,
&bitcode::LIST_WALK,
)
pub enum ListWalk {
Walk,
WalkBackwards,
WalkUntil,
WalkBackwardsUntil,
}
/// List.walkBackwards : List elem, (elem -> accum -> accum), accum -> accum
pub fn list_walk_backwards<'a, 'ctx, 'env>(
pub fn list_walk_help<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
scope: &crate::llvm::build::Scope<'a, 'ctx>,
parent: FunctionValue<'ctx>,
list: BasicValueEnum<'ctx>,
element_layout: &Layout<'a>,
func: BasicValueEnum<'ctx>,
func_layout: &Layout<'a>,
default: BasicValueEnum<'ctx>,
default_layout: &Layout<'a>,
args: &[roc_module::symbol::Symbol],
variant: ListWalk,
) -> BasicValueEnum<'ctx> {
list_walk_generic(
env,
layout_ids,
parent,
list,
element_layout,
func,
func_layout,
default,
default_layout,
&bitcode::LIST_WALK_BACKWARDS,
)
use crate::llvm::build::load_symbol_and_layout;
debug_assert_eq!(args.len(), 3);
let (list, list_layout) = load_symbol_and_layout(scope, &args[0]);
let (func, func_layout) = load_symbol_and_layout(scope, &args[1]);
let (default, default_layout) = load_symbol_and_layout(scope, &args[2]);
match list_layout {
Layout::Builtin(Builtin::EmptyList) => default,
Layout::Builtin(Builtin::List(_, element_layout)) => list_walk_generic(
env,
layout_ids,
parent,
list,
element_layout,
func,
func_layout,
default,
default_layout,
variant,
),
_ => unreachable!("invalid list layout"),
}
}
fn list_walk_generic<'a, 'ctx, 'env>(
@ -925,10 +916,17 @@ fn list_walk_generic<'a, 'ctx, 'env>(
func_layout: &Layout<'a>,
default: BasicValueEnum<'ctx>,
default_layout: &Layout<'a>,
zig_function: &str,
variant: ListWalk,
) -> BasicValueEnum<'ctx> {
let builder = env.builder;
let zig_function = match variant {
ListWalk::Walk => bitcode::LIST_WALK,
ListWalk::WalkBackwards => bitcode::LIST_WALK_BACKWARDS,
ListWalk::WalkUntil => bitcode::LIST_WALK_UNTIL,
ListWalk::WalkBackwardsUntil => todo!(),
};
let u8_ptr = env.context.i8_type().ptr_type(AddressSpace::Generic);
let list_i128 = complex_bitcast(env.builder, list, env.context.i128_type().into(), "to_i128");
@ -961,21 +959,44 @@ fn list_walk_generic<'a, 'ctx, 'env>(
let result_ptr = env.builder.build_alloca(default.get_type(), "result");
call_void_bitcode_fn(
env,
&[
list_i128,
env.builder
.build_bitcast(transform_ptr, u8_ptr, "to_opaque"),
stepper_caller.into(),
env.builder.build_bitcast(default_ptr, u8_ptr, "to_u8_ptr"),
alignment_iv.into(),
element_width.into(),
default_width.into(),
env.builder.build_bitcast(result_ptr, u8_ptr, "to_opaque"),
],
zig_function,
);
match variant {
ListWalk::Walk | ListWalk::WalkBackwards => {
call_void_bitcode_fn(
env,
&[
list_i128,
env.builder
.build_bitcast(transform_ptr, u8_ptr, "to_opaque"),
stepper_caller.into(),
env.builder.build_bitcast(default_ptr, u8_ptr, "to_u8_ptr"),
alignment_iv.into(),
element_width.into(),
default_width.into(),
env.builder.build_bitcast(result_ptr, u8_ptr, "to_opaque"),
],
zig_function,
);
}
ListWalk::WalkUntil | ListWalk::WalkBackwardsUntil => {
let dec_element_fn = build_dec_wrapper(env, layout_ids, element_layout);
call_void_bitcode_fn(
env,
&[
list_i128,
env.builder
.build_bitcast(transform_ptr, u8_ptr, "to_opaque"),
stepper_caller.into(),
env.builder.build_bitcast(default_ptr, u8_ptr, "to_u8_ptr"),
alignment_iv.into(),
element_width.into(),
default_width.into(),
dec_element_fn.as_global_value().as_pointer_value().into(),
env.builder.build_bitcast(result_ptr, u8_ptr, "to_opaque"),
],
zig_function,
);
}
}
env.builder.build_load(result_ptr, "load_result")
}

View file

@ -111,12 +111,16 @@ fn generate_module_doc<'a>(
},
Alias {
name: _,
name,
vars: _,
ann: _,
} =>
// TODO
{
} => {
let entry = DocEntry {
name: name.value.to_string(),
docs: before_comments_or_new_lines.and_then(comments_or_new_lines_to_docs),
};
acc.push(entry);
(acc, None)
}
@ -139,7 +143,9 @@ fn comments_or_new_lines_to_docs<'a>(
docs.push_str(doc_str);
docs.push('\n');
}
Newline | LineComment(_) => {}
Newline | LineComment(_) => {
docs = String::new();
}
}
}
if docs.is_empty() {

View file

@ -32,6 +32,7 @@ pub enum LowLevel {
ListMapWithIndex,
ListKeepIf,
ListWalk,
ListWalkUntil,
ListWalkBackwards,
ListSum,
ListProduct,

View file

@ -915,6 +915,7 @@ define_builtins! {
24 LIST_MAP2: "map2"
25 LIST_MAP3: "map3"
26 LIST_PRODUCT: "product"
27 LIST_WALK_UNTIL: "walkUntil"
}
5 RESULT: "Result" => {
0 RESULT_RESULT: "Result" imported // the Result.Result type alias

View file

@ -655,8 +655,9 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] {
ListMap3 => arena.alloc_slice_copy(&[owned, owned, owned, irrelevant]),
ListKeepIf | ListKeepOks | ListKeepErrs => arena.alloc_slice_copy(&[owned, borrowed]),
ListContains => arena.alloc_slice_copy(&[borrowed, irrelevant]),
ListWalk => arena.alloc_slice_copy(&[owned, irrelevant, owned]),
ListWalkBackwards => arena.alloc_slice_copy(&[owned, irrelevant, owned]),
ListWalk | ListWalkUntil | ListWalkBackwards => {
arena.alloc_slice_copy(&[owned, irrelevant, owned])
}
ListSum | ListProduct => arena.alloc_slice_copy(&[borrowed]),
// TODO when we have lists with capacity (if ever)

View file

@ -2839,11 +2839,69 @@ pub fn with_hole<'a>(
variant_var,
name: tag_name,
arguments: args,
..
ext_var,
} => {
use crate::layout::UnionVariant::*;
let arena = env.arena;
let desc = env.subs.get_without_compacting(variant_var);
if let Content::Structure(FlatType::Func(arg_vars, _, ret_var)) = desc.content {
let mut loc_pattern_args = vec![];
let mut loc_expr_args = vec![];
let proc_symbol = env.unique_symbol();
for arg_var in arg_vars {
let arg_symbol = env.unique_symbol();
let loc_pattern =
Located::at_zero(roc_can::pattern::Pattern::Identifier(arg_symbol));
let loc_expr = Located::at_zero(roc_can::expr::Expr::Var(arg_symbol));
loc_pattern_args.push((arg_var, loc_pattern));
loc_expr_args.push((arg_var, loc_expr));
}
let loc_body = Located::at_zero(roc_can::expr::Expr::Tag {
variant_var: ret_var,
name: tag_name,
arguments: loc_expr_args,
ext_var,
});
let inserted = procs.insert_anonymous(
env,
proc_symbol,
variant_var,
loc_pattern_args,
loc_body,
CapturedSymbols::None,
ret_var,
layout_cache,
);
match inserted {
Ok(layout) => {
return Stmt::Let(
assigned,
call_by_pointer(env, procs, proc_symbol, layout),
layout,
hole,
);
}
Err(runtime_error) => {
return Stmt::RuntimeError(env.arena.alloc(format!(
"RuntimeError {} line {} {:?}",
file!(),
line!(),
runtime_error,
)));
}
}
}
let res_variant = crate::layout::union_sorted_tags(env.arena, variant_var, env.subs);
let variant = match res_variant {

View file

@ -4613,7 +4613,6 @@ mod test_reporting {
#[test]
fn type_apply_stray_dot() {
// TODO good message
report_problem_as(
indoc!(
r#"
@ -6052,6 +6051,39 @@ mod test_reporting {
)
}
#[test]
fn applied_tag_function() {
report_problem_as(
indoc!(
r#"
x : List [ Foo Str ]
x = List.map [ 1, 2 ] Foo
x
"#
),
indoc!(
r#"
TYPE MISMATCH
Something is off with the body of the `x` definition:
1 x : List [ Foo Str ]
2 x = List.map [ 1, 2 ] Foo
^^^^^^^^^^^^^^^^^^^^^
This `map` call produces:
List [ Foo Num a ]
But the type annotation on `x` says it should be:
List [ Foo Str ]
"#
),
)
}
#[test]
fn pattern_in_parens_open() {
report_problem_as(

View file

@ -566,6 +566,107 @@ mod solve_expr {
);
}
#[test]
fn applied_tag() {
infer_eq_without_problem(
indoc!(
r#"
List.map [ "a", "b" ] \elem -> Foo elem
"#
),
"List [ Foo Str ]*",
)
}
// Tests (TagUnion, Func)
#[test]
fn applied_tag_function() {
infer_eq_without_problem(
indoc!(
r#"
foo = Foo
foo "hi"
"#
),
"[ Foo Str ]*",
)
}
// Tests (TagUnion, Func)
#[test]
fn applied_tag_function_list_map() {
infer_eq_without_problem(
indoc!(
r#"
List.map [ "a", "b" ] Foo
"#
),
"List [ Foo Str ]*",
)
}
// Tests (TagUnion, Func)
#[test]
fn applied_tag_function_list() {
infer_eq_without_problem(
indoc!(
r#"
[ \x -> Bar x, Foo ]
"#
),
"List (a -> [ Bar a, Foo a ]*)",
)
}
// Tests (Func, TagUnion)
#[test]
fn applied_tag_function_list_other_way() {
infer_eq_without_problem(
indoc!(
r#"
[ Foo, \x -> Bar x ]
"#
),
"List (a -> [ Bar a, Foo a ]*)",
)
}
// Tests (Func, TagUnion)
#[test]
fn applied_tag_function_record() {
infer_eq_without_problem(
indoc!(
r#"
foo = Foo
{
x: [ foo, Foo ],
y: [ foo, \x -> Foo x ],
z: [ foo, \x,y -> Foo x y ]
}
"#
),
"{ x : List [ Foo ]*, y : List (a -> [ Foo a ]*), z : List (b, c -> [ Foo b c ]*) }",
)
}
// Tests (TagUnion, Func)
#[test]
fn applied_tag_function_with_annotation() {
infer_eq_without_problem(
indoc!(
r#"
x : List [ Foo I64 ]
x = List.map [ 1, 2 ] Foo
x
"#
),
"List [ Foo I64 ]",
)
}
#[test]
fn def_2_arg_closure() {
infer_eq(

View file

@ -320,6 +320,32 @@ fn list_walk_substraction() {
assert_evals_to!(r#"List.walk [ 1, 2 ] Num.sub 1"#, 2, i64);
}
#[test]
fn list_walk_until_sum() {
assert_evals_to!(
r#"List.walkUntil [ 1, 2 ] (\a,b -> Continue (a + b)) 0"#,
3,
i64
);
}
#[test]
fn list_walk_until_even_prefix_sum() {
assert_evals_to!(
r#"
helper = \a, b ->
if Num.isEven a then
Continue (a + b)
else
Stop b
List.walkUntil [ 2, 4, 8, 9 ] helper 0"#,
2 + 4 + 8,
i64
);
}
#[test]
fn list_keep_if_empty_list_of_int() {
assert_evals_to!(

View file

@ -3,6 +3,7 @@
use crate::assert_evals_to;
use crate::assert_llvm_evals_to;
use indoc::indoc;
use roc_std::{RocList, RocStr};
#[test]
fn applied_tag_nothing_ir() {
@ -974,3 +975,61 @@ fn newtype_wrapper() {
|x: &i64| *x
);
}
#[test]
fn applied_tag_function() {
assert_evals_to!(
indoc!(
r#"
x : List [ Foo Str ]
x = List.map [ "a", "b" ] Foo
x
"#
),
RocList::from_slice(&[
RocStr::from_slice("a".as_bytes()),
RocStr::from_slice("b".as_bytes())
]),
RocList<RocStr>
);
}
#[test]
fn applied_tag_function_result() {
assert_evals_to!(
indoc!(
r#"
x : List (Result Str *)
x = List.map [ "a", "b" ] Ok
x
"#
),
RocList::from_slice(&[
(1, RocStr::from_slice("a".as_bytes())),
(1, RocStr::from_slice("b".as_bytes()))
]),
RocList<(i64, RocStr)>
);
}
#[test]
fn applied_tag_function_linked_list() {
assert_evals_to!(
indoc!(
r#"
ConsList a : [ Nil, Cons a (ConsList a) ]
x : List (ConsList Str)
x = List.map2 [ "a", "b" ] [ Nil, Cons "c" Nil ] Cons
when List.first x is
Ok (Cons "a" Nil) -> 1
_ -> 0
"#
),
1,
i64
);
}

View file

@ -1,4 +1,6 @@
use roc_collections::all::{get_shared, relative_complement, union, MutMap, SendSet};
use roc_collections::all::{
default_hasher, get_shared, relative_complement, union, MutMap, SendSet,
};
use roc_module::ident::{Lowercase, TagName};
use roc_module::symbol::Symbol;
use roc_types::boolean_algebra::Bool;
@ -1069,6 +1071,12 @@ fn unify_flat_type(
problems
}
}
(TagUnion(tags, ext), Func(args, closure, ret)) if tags.len() == 1 => {
unify_tag_union_and_func(tags, args, subs, pool, ctx, ext, ret, closure, true)
}
(Func(args, closure, ret), TagUnion(tags, ext)) if tags.len() == 1 => {
unify_tag_union_and_func(tags, args, subs, pool, ctx, ext, ret, closure, false)
}
(other1, other2) => mismatch!(
"Trying to unify two flat types that are incompatible: {:?} ~ {:?}",
other1,
@ -1250,3 +1258,54 @@ fn is_recursion_var(subs: &Subs, var: Variable) -> bool {
Content::RecursionVar { .. }
)
}
#[allow(clippy::too_many_arguments, clippy::ptr_arg)]
fn unify_tag_union_and_func(
tags: &MutMap<TagName, Vec<Variable>>,
args: &Vec<Variable>,
subs: &mut Subs,
pool: &mut Pool,
ctx: &Context,
ext: &Variable,
ret: &Variable,
closure: &Variable,
left: bool,
) -> Outcome {
use FlatType::*;
let (tag_name, payload) = tags.iter().next().unwrap();
if payload.is_empty() {
let mut new_tags = MutMap::with_capacity_and_hasher(1, default_hasher());
new_tags.insert(tag_name.clone(), args.to_owned());
let content = Structure(TagUnion(new_tags, *ext));
let new_tag_union_var = fresh(subs, pool, ctx, content);
let problems = if left {
unify_pool(subs, pool, new_tag_union_var, *ret)
} else {
unify_pool(subs, pool, *ret, new_tag_union_var)
};
if problems.is_empty() {
let desc = if left {
subs.get(ctx.second)
} else {
subs.get(ctx.first)
};
subs.union(ctx.first, ctx.second, desc);
}
problems
} else {
mismatch!(
"Trying to unify two flat types that are incompatible: {:?} ~ {:?}",
TagUnion(tags.clone(), *ext),
Func(args.to_owned(), *closure, *ret)
)
}
}