Merge pull request #684 from rtfeldman/big-nested-pattern-match

pattern matching fixes
This commit is contained in:
Richard Feldman 2020-11-15 08:24:31 -05:00 committed by GitHub
commit a55f30e512
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 1060 additions and 662 deletions

View file

@ -431,6 +431,7 @@ pub fn constrain_expr(
match expected {
FromAnnotation(name, arity, _, tipe) => {
let num_branches = branches.len() + 1;
for (index, (loc_cond, loc_body)) in branches.iter().enumerate() {
let cond_con = constrain_expr(
env,
@ -448,7 +449,7 @@ pub fn constrain_expr(
arity,
AnnotationSource::TypedIfBranch {
index: Index::zero_based(index),
num_branches: branches.len(),
num_branches,
},
tipe.clone(),
),
@ -467,7 +468,7 @@ pub fn constrain_expr(
arity,
AnnotationSource::TypedIfBranch {
index: Index::zero_based(branches.len()),
num_branches: branches.len(),
num_branches,
},
tipe.clone(),
),
@ -558,15 +559,12 @@ pub fn constrain_expr(
constraints.push(expr_con);
match &expected {
FromAnnotation(name, arity, _, typ) => {
// record the type of the whole expression in the AST
let ast_con = Eq(
Type::Variable(*expr_var),
expected.clone(),
Category::Storage(std::file!(), std::line!()),
region,
);
constraints.push(ast_con);
FromAnnotation(name, arity, _, _typ) => {
// NOTE deviation from elm.
//
// in elm, `_typ` is used, but because we have this `expr_var` too
// and need to constrain it, this is what works and gives better error messages
let typ = Type::Variable(*expr_var);
for (index, when_branch) in branches.iter().enumerate() {
let pattern_region =
@ -595,6 +593,10 @@ pub fn constrain_expr(
constraints.push(branch_con);
}
constraints.push(Eq(typ, expected, Category::When, region));
return exists(vec![cond_var, *expr_var], And(constraints));
}
_ => {
@ -1119,9 +1121,11 @@ fn constrain_def(env: &Env, def: &Def, body_con: Constraint) -> Constraint {
name,
..
},
Type::Function(arg_types, _, _),
Type::Function(arg_types, _closure_type, ret_type),
) => {
let expected = annotation_expected;
// NOTE if we ever have problems with the closure, the ignored `_closure_type`
// is probably a good place to start the investigation!
let region = def.loc_expr.region;
let loc_body_expr = &**loc_body;
@ -1135,7 +1139,7 @@ fn constrain_def(env: &Env, def: &Def, body_con: Constraint) -> Constraint {
let ret_var = *ret_var;
let closure_var = *closure_var;
let closure_ext_var = *closure_ext_var;
let ret_type = Type::Variable(ret_var);
let ret_type = *ret_type.clone();
vars.push(ret_var);
vars.push(closure_var);
@ -1197,12 +1201,15 @@ fn constrain_def(env: &Env, def: &Def, body_con: Constraint) -> Constraint {
&mut vars,
);
let fn_type = Type::Function(
pattern_types,
Box::new(Type::Variable(closure_var)),
Box::new(ret_type.clone()),
let body_type = FromAnnotation(
def.loc_pattern.clone(),
arguments.len(),
AnnotationSource::TypedBody {
region: annotation.region,
},
ret_type.clone(),
);
let body_type = NoExpectation(ret_type);
let ret_constraint =
constrain_expr(env, loc_body_expr.region, &loc_body_expr.value, body_type);
@ -1219,22 +1226,32 @@ fn constrain_def(env: &Env, def: &Def, body_con: Constraint) -> Constraint {
defs_constraint,
ret_constraint,
})),
// "the closure's type is equal to expected type"
Eq(fn_type, expected, Category::Lambda, region),
// Store type into AST vars. We use Store so errors aren't reported twice
Store(signature.clone(), *fn_var, std::file!(), std::line!()),
Store(signature, expr_var, std::file!(), std::line!()),
Store(ret_type, ret_var, std::file!(), std::line!()),
closure_constraint,
]),
)
}
_ => constrain_expr(
&env,
def.loc_expr.region,
&def.loc_expr.value,
annotation_expected,
),
_ => {
let expected = annotation_expected;
let ret_constraint =
constrain_expr(env, def.loc_expr.region, &def.loc_expr.value, expected);
And(vec![
Let(Box::new(LetConstraint {
rigid_vars: Vec::new(),
flex_vars: vec![],
def_types: SendMap::default(),
defs_constraint: True,
ret_constraint,
})),
// Store type into AST vars. We use Store so errors aren't reported twice
Store(signature, expr_var, std::file!(), std::line!()),
])
}
}
}
None => {
@ -1440,8 +1457,11 @@ pub fn rec_defs_help(
name,
..
},
Type::Function(arg_types, _, _),
Type::Function(arg_types, _closure_type, ret_type),
) => {
// NOTE if we ever have trouble with closure type unification, the ignored
// `_closure_type` here is a good place to start investigating
let expected = annotation_expected;
let region = def.loc_expr.region;
@ -1456,7 +1476,7 @@ pub fn rec_defs_help(
let ret_var = *ret_var;
let closure_var = *closure_var;
let closure_ext_var = *closure_ext_var;
let ret_type = Type::Variable(ret_var);
let ret_type = *ret_type.clone();
vars.push(ret_var);
vars.push(closure_var);
@ -1523,7 +1543,7 @@ pub fn rec_defs_help(
Box::new(Type::Variable(closure_var)),
Box::new(ret_type.clone()),
);
let body_type = NoExpectation(ret_type);
let body_type = NoExpectation(ret_type.clone());
let expr_con = constrain_expr(
env,
loc_body_expr.region,
@ -1548,6 +1568,7 @@ pub fn rec_defs_help(
// Store type into AST vars. We use Store so errors aren't reported twice
Store(signature.clone(), *fn_var, std::file!(), std::line!()),
Store(signature, expr_var, std::file!(), std::line!()),
Store(ret_type, ret_var, std::file!(), std::line!()),
closure_constraint,
]),
);

View file

@ -752,30 +752,27 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
let mut field_types = Vec::with_capacity_in(num_fields, env.arena);
let mut field_vals = Vec::with_capacity_in(num_fields, env.arena);
for (field_symbol, tag_field_layout) in
arguments.iter().zip(fields[*tag_id as usize].iter())
{
// note field_layout is the layout of the argument.
// tag_field_layout is the layout that the tag will store
// these are different for recursive tag unions
let (val, field_layout) = load_symbol_and_layout(env, scope, field_symbol);
let field_size = tag_field_layout.stack_size(ptr_size);
let tag_field_layouts = fields[*tag_id as usize];
for (field_symbol, tag_field_layout) in arguments.iter().zip(tag_field_layouts.iter()) {
let val = load_symbol(env, scope, field_symbol);
// Zero-sized fields have no runtime representation.
// The layout of the struct expects them to be dropped!
if field_size != 0 {
if !tag_field_layout.is_dropped_because_empty() {
let field_type =
basic_type_from_layout(env.arena, env.context, tag_field_layout, ptr_size);
field_types.push(field_type);
if let Layout::RecursivePointer = tag_field_layout {
let ptr = allocate_with_refcount(env, field_layout, val).into();
let ptr = allocate_with_refcount(env, &tag_layout, val);
let ptr = cast_basic_basic(
builder,
ptr,
ptr.into(),
ctx.i64_type().ptr_type(AddressSpace::Generic).into(),
);
field_vals.push(ptr);
} else {
field_vals.push(val);
@ -993,8 +990,7 @@ pub fn allocate_with_refcount<'a, 'ctx, 'env>(
// bytes per element
let bytes_len = len_type.const_int(value_bytes, false);
// TODO fix offset
let offset = (env.ptr_bytes as u64).max(value_bytes);
let offset = crate::llvm::refcounting::refcount_offset(env, layout);
let ptr = {
let len = bytes_len;
@ -1011,7 +1007,7 @@ pub fn allocate_with_refcount<'a, 'ctx, 'env>(
// We must return a pointer to the first element:
let ptr_bytes = env.ptr_bytes;
let int_type = ptr_int(ctx, ptr_bytes);
let ptr_as_int = builder.build_ptr_to_int(ptr, int_type, "list_cast_ptr");
let ptr_as_int = builder.build_ptr_to_int(ptr, int_type, "allocate_refcount_pti");
let incremented = builder.build_int_add(
ptr_as_int,
ctx.i64_type().const_int(offset, false),
@ -1019,7 +1015,7 @@ pub fn allocate_with_refcount<'a, 'ctx, 'env>(
);
let ptr_type = get_ptr_type(&value_type, AddressSpace::Generic);
let list_element_ptr = builder.build_int_to_ptr(incremented, ptr_type, "list_cast_ptr");
let list_element_ptr = builder.build_int_to_ptr(incremented, ptr_type, "allocate_refcount_itp");
// subtract ptr_size, to access the refcount
let refcount_ptr = builder.build_int_sub(

View file

@ -89,46 +89,7 @@ pub fn decrement_refcount_layout<'a, 'ctx, 'env>(
RecursivePointer => todo!("TODO implement decrement layout of recursive tag union"),
Union(tags) => {
debug_assert!(!tags.is_empty());
let wrapper_struct = value.into_struct_value();
// read the tag_id
let tag_id = env
.builder
.build_extract_value(wrapper_struct, 0, "read_tag_id")
.unwrap()
.into_int_value();
// next, make a jump table for all possible values of the tag_id
let mut cases = Vec::with_capacity_in(tags.len(), env.arena);
let merge_block = env.context.append_basic_block(parent, "decrement_merge");
for (tag_id, field_layouts) in tags.iter().enumerate() {
let block = env.context.append_basic_block(parent, "tag_id_decrement");
env.builder.position_at_end(block);
for (i, field_layout) in field_layouts.iter().enumerate() {
if field_layout.contains_refcounted() {
let field_ptr = env
.builder
.build_extract_value(wrapper_struct, i as u32, "decrement_struct_field")
.unwrap();
decrement_refcount_layout(env, parent, layout_ids, field_ptr, field_layout)
}
}
env.builder.build_unconditional_branch(merge_block);
cases.push((env.context.i8_type().const_int(tag_id as u64, false), block));
}
let (_, default_block) = cases.pop().unwrap();
env.builder.build_switch(tag_id, default_block, &cases);
env.builder.position_at_end(merge_block);
build_dec_union(env, layout_ids, tags, value);
}
RecursiveUnion(tags) => {
@ -749,6 +710,7 @@ fn decrement_refcount_help<'a, 'ctx, 'env>(
],
)
.into_struct_value();
let has_overflowed = builder
.build_extract_value(add_with_overflow, 1, "has_overflowed")
.unwrap();
@ -759,6 +721,7 @@ fn decrement_refcount_help<'a, 'ctx, 'env>(
ctx.bool_type().const_int(1 as u64, false),
"has_overflowed",
);
// build blocks
let then_block = ctx.append_basic_block(parent, "then");
let else_block = ctx.append_basic_block(parent, "else");
@ -780,6 +743,7 @@ fn decrement_refcount_help<'a, 'ctx, 'env>(
// build else block
{
builder.position_at_end(else_block);
let max = builder.build_int_compare(
IntPredicate::EQ,
refcount,
@ -903,14 +867,20 @@ pub fn build_dec_union_help<'a, 'ctx, 'env>(
let wrapper_struct = arg_val.into_struct_value();
// let tag_id_u8 = cast_basic_basic(env.builder, tag_id.into(), env.context.i8_type().into());
// next, make a jump table for all possible values of the tag_id
let mut cases = Vec::with_capacity_in(tags.len(), env.arena);
let merge_block = env.context.append_basic_block(parent, "decrement_merge");
for (tag_id, field_layouts) in tags.iter().enumerate() {
// if none of the fields are or contain anything refcounted, just move on
if !field_layouts
.iter()
.any(|x| x.is_refcounted() || x.contains_refcounted())
{
continue;
}
let block = env.context.append_basic_block(parent, "tag_id_decrement");
env.builder.position_at_end(block);
@ -926,18 +896,19 @@ pub fn build_dec_union_help<'a, 'ctx, 'env>(
for (i, field_layout) in field_layouts.iter().enumerate() {
if let Layout::RecursivePointer = field_layout {
// a *i64 pointer to the recursive data
// we need to cast this pointer to the appropriate type
let field_ptr = env
// this field has type `*i64`, but is really a pointer to the data we want
let ptr_as_i64_ptr = env
.builder
.build_extract_value(wrapper_struct, i as u32, "decrement_struct_field")
.unwrap();
// recursively decrement
debug_assert!(ptr_as_i64_ptr.is_pointer_value());
// therefore we must cast it to our desired type
let union_type = block_of_memory(env.context, &layout, env.ptr_bytes);
let recursive_field_ptr = cast_basic_basic(
env.builder,
field_ptr,
ptr_as_i64_ptr,
union_type.ptr_type(AddressSpace::Generic).into(),
)
.into_pointer_value();
@ -956,7 +927,7 @@ pub fn build_dec_union_help<'a, 'ctx, 'env>(
// TODO do this decrement before the recursive call?
// Then the recursive call is potentially TCE'd
decrement_refcount_ptr(env, parent, &layout, field_ptr.into_pointer_value());
decrement_refcount_ptr(env, parent, &layout, recursive_field_ptr);
} else if field_layout.contains_refcounted() {
let field_ptr = env
.builder
@ -977,8 +948,6 @@ pub fn build_dec_union_help<'a, 'ctx, 'env>(
cases.reverse();
let (_, default_block) = cases.pop().unwrap();
env.builder.position_at_end(before_block);
// read the tag_id
@ -998,7 +967,7 @@ pub fn build_dec_union_help<'a, 'ctx, 'env>(
// switch on it
env.builder
.build_switch(current_tag_id, default_block, &cases);
.build_switch(current_tag_id, merge_block, &cases);
env.builder.position_at_end(merge_block);
@ -1105,10 +1074,18 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>(
// next, make a jump table for all possible values of the tag_id
let mut cases = Vec::with_capacity_in(tags.len(), env.arena);
let merge_block = env.context.append_basic_block(parent, "decrement_merge");
let merge_block = env.context.append_basic_block(parent, "increment_merge");
for (tag_id, field_layouts) in tags.iter().enumerate() {
let block = env.context.append_basic_block(parent, "tag_id_decrement");
// if none of the fields are or contain anything refcounted, just move on
if !field_layouts
.iter()
.any(|x| x.is_refcounted() || x.contains_refcounted())
{
continue;
}
let block = env.context.append_basic_block(parent, "tag_id_increment");
env.builder.position_at_end(block);
let wrapper_type = basic_type_from_layout(
@ -1123,18 +1100,19 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>(
for (i, field_layout) in field_layouts.iter().enumerate() {
if let Layout::RecursivePointer = field_layout {
// a *i64 pointer to the recursive data
// we need to cast this pointer to the appropriate type
let field_ptr = env
// this field has type `*i64`, but is really a pointer to the data we want
let ptr_as_i64_ptr = env
.builder
.build_extract_value(wrapper_struct, i as u32, "decrement_struct_field")
.build_extract_value(wrapper_struct, i as u32, "increment_struct_field")
.unwrap();
// recursively increment
debug_assert!(ptr_as_i64_ptr.is_pointer_value());
// therefore we must cast it to our desired type
let union_type = block_of_memory(env.context, &layout, env.ptr_bytes);
let recursive_field_ptr = cast_basic_basic(
env.builder,
field_ptr,
ptr_as_i64_ptr,
union_type.ptr_type(AddressSpace::Generic).into(),
)
.into_pointer_value();
@ -1151,9 +1129,9 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>(
// Because it's an internal-only function, use the fast calling convention.
call.set_call_convention(FAST_CALL_CONV);
// TODO do this increment before the recursive call?
// TODO do this decrement before the recursive call?
// Then the recursive call is potentially TCE'd
increment_refcount_ptr(env, &layout, field_ptr.into_pointer_value());
increment_refcount_ptr(env, &layout, recursive_field_ptr);
} else if field_layout.contains_refcounted() {
let field_ptr = env
.builder
@ -1169,12 +1147,10 @@ pub fn build_inc_union_help<'a, 'ctx, 'env>(
cases.push((env.context.i8_type().const_int(tag_id as u64, false), block));
}
let (_, default_block) = cases.pop().unwrap();
env.builder.position_at_end(before_block);
env.builder
.build_switch(tag_id_u8.into_int_value(), default_block, &cases);
.build_switch(tag_id_u8.into_int_value(), merge_block, &cases);
env.builder.position_at_end(merge_block);
@ -1221,6 +1197,17 @@ fn get_refcount_ptr<'a, 'ctx, 'env>(
get_refcount_ptr_help(env, layout, ptr_as_int)
}
pub fn refcount_offset<'a, 'ctx, 'env>(env: &Env<'a, 'ctx, 'env>, layout: &Layout<'a>) -> u64 {
let value_bytes = layout.stack_size(env.ptr_bytes) as u64;
match layout {
Layout::Builtin(Builtin::List(_, _)) => env.ptr_bytes as u64,
Layout::Builtin(Builtin::Str) => env.ptr_bytes as u64,
Layout::RecursivePointer | Layout::RecursiveUnion(_) => env.ptr_bytes as u64,
_ => (env.ptr_bytes as u64).max(value_bytes),
}
}
fn get_refcount_ptr_help<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout: &Layout<'a>,
@ -1229,12 +1216,7 @@ fn get_refcount_ptr_help<'a, 'ctx, 'env>(
let builder = env.builder;
let ctx = env.context;
let value_bytes = layout.stack_size(env.ptr_bytes) as u64;
let offset = match layout {
Layout::Builtin(Builtin::List(_, _)) => env.ptr_bytes as u64,
Layout::Builtin(Builtin::Str) => env.ptr_bytes as u64,
_ => (env.ptr_bytes as u64).max(value_bytes),
};
let offset = refcount_offset(env, layout);
// pointer to usize
let refcount_type = ptr_int(ctx, env.ptr_bytes);

View file

@ -1277,6 +1277,108 @@ mod gen_primitives {
);
}
#[test]
#[ignore]
fn rbtree_balance_inc_dec() {
// TODO does not define a variable correctly, but all is well with the type signature
assert_non_opt_evals_to!(
indoc!(
r#"
app Test provides [ main ] imports []
NodeColor : [ Red, Black ]
Dict k : [ Node NodeColor k (Dict k) (Dict k), Empty ]
# balance : NodeColor, k, Dict k, Dict k -> Dict k
balance = \color, key, left, right ->
when right is
Node Red rK rLeft rRight ->
when left is
Node Red _ _ _ ->
Node
Red
key
Empty
Empty
_ ->
Node color rK (Node Red key left rLeft) rRight
_ ->
Empty
main : Dict Int
main =
balance Red 0 Empty Empty
"#
),
0,
i64
);
}
#[test]
fn rbtree_balance_3() {
assert_non_opt_evals_to!(
indoc!(
r#"
app Test provides [ main ] imports []
Dict k : [ Node k (Dict k) (Dict k), Empty ]
balance : k, Dict k -> Dict k
balance = \key, left ->
Node key left Empty
main : Dict Int
main =
balance 0 Empty
"#
),
1,
i64
);
}
#[test]
fn rbtree_balance_2() {
assert_non_opt_evals_to!(
indoc!(
r#"
app Test provides [ main ] imports []
NodeColor : [ Red, Black ]
Dict k : [ Node NodeColor k (Dict k), Empty ]
balance : NodeColor, k, Dict k, Dict k -> Dict k
balance = \color, key, left, right ->
when right is
Node Red rK _ ->
when left is
Node Red _ _ ->
Node
Red
key
Empty
_ ->
Node color rK (Node Red key left )
_ ->
Empty
main : Dict Int
main =
balance Red 0 Empty Empty
"#
),
0,
i64
);
}
#[test]
#[ignore]
fn rbtree_balance() {
@ -1289,18 +1391,38 @@ mod gen_primitives {
Dict k v : [ Node NodeColor k v (Dict k v) (Dict k v), Empty ]
Key k : Num k
balance : NodeColor, k, v, Dict k v, Dict k v -> Dict k v
balance = \color, key, value, left, right ->
when right is
Node Red lK lV (Node Red llK llV llLeft llRight) lRight -> Empty
Empty -> Empty
Node Red rK rV rLeft rRight ->
when left is
Node Red lK lV lLeft lRight ->
Node
Red
key
value
(Node Black lK lV lLeft lRight)
(Node Black rK rV rLeft rRight)
_ ->
Node color rK rV (Node Red key value left rLeft) rRight
main : Dict Int {}
_ ->
when left is
Node Red lK lV (Node Red llK llV llLeft llRight) lRight ->
Node
Red
lK
lV
(Node Black llK llV llLeft llRight)
(Node Black key value lRight right)
_ ->
Node color key value left right
main : Dict Int Int
main =
balance Red 0 {} Empty Empty
balance Red 0 0 Empty Empty
"#
),
1,
@ -1310,6 +1432,34 @@ mod gen_primitives {
#[test]
#[ignore]
fn linked_list_guarded_double_pattern_match() {
// the important part here is that the first case (with the nested Cons) does not match
// TODO this also has undefined behavior
assert_non_opt_evals_to!(
indoc!(
r#"
app Test provides [ main ] imports []
ConsList a : [ Cons a (ConsList a), Nil ]
balance : ConsList Int -> Int
balance = \right ->
when right is
Cons 1 (Cons 1 _) -> 3
_ -> 3
main : Int
main =
when balance Nil is
_ -> 3
"#
),
3,
i64
);
}
#[test]
fn linked_list_double_pattern_match() {
assert_non_opt_evals_to!(
indoc!(

View file

@ -11,6 +11,7 @@ use roc_module::symbol::Symbol;
/// COMPILE CASES
type Label = u64;
const RECORD_TAG_NAME: &str = "#Record";
/// Users of this module will mainly interact with this function. It takes
/// some normal branches and gives out a decision tree that has "labels" at all
@ -189,7 +190,7 @@ fn to_decision_tree(raw_branches: Vec<Branch>) -> DecisionTree {
fn is_complete(tests: &[Test]) -> bool {
let length = tests.len();
debug_assert!(length > 0);
match tests.get(length - 1) {
match tests.last() {
None => unreachable!("should never happen"),
Some(v) => match v {
Test::IsCtor { union, .. } => length == union.alternatives.len(),
@ -395,7 +396,7 @@ fn test_at_path<'a>(selected_path: &Path, branch: &Branch<'a>, all_tests: &mut V
render_as: RenderAs::Tag,
alternatives: vec![Ctor {
tag_id: TagId(0),
name: TagName::Global("#Record".into()),
name: TagName::Global(RECORD_TAG_NAME.into()),
arity: destructs.len(),
}],
};
@ -418,7 +419,7 @@ fn test_at_path<'a>(selected_path: &Path, branch: &Branch<'a>, all_tests: &mut V
all_tests.push(IsCtor {
tag_id: 0,
tag_name: TagName::Global("#Record".into()),
tag_name: TagName::Global(RECORD_TAG_NAME.into()),
union,
arguments,
});
@ -538,7 +539,7 @@ fn to_relevant_branch_help<'a>(
tag_id,
..
} => {
debug_assert!(test_name == &TagName::Global("#Record".into()));
debug_assert!(test_name == &TagName::Global(RECORD_TAG_NAME.into()));
let sub_positions = destructs.into_iter().enumerate().map(|(index, destruct)| {
let pattern = match destruct.typ {
DestructType::Guard(guard) => guard.clone(),
@ -934,32 +935,59 @@ pub fn optimize_when<'a>(
)
}
fn path_to_expr_help<'a>(
env: &mut Env<'a, '_>,
mut symbol: Symbol,
mut path: &Path,
mut layout: Layout<'a>,
) -> (Symbol, StoresVec<'a>, Layout<'a>) {
let mut stores = bumpalo::collections::Vec::new_in(env.arena);
#[derive(Debug)]
struct PathInstruction {
index: u64,
tag_id: u8,
}
fn reverse_path(mut path: &Path) -> Vec<PathInstruction> {
let mut result = Vec::new();
loop {
match path {
Path::Unbox(unboxed) => {
path = unboxed;
Path::Unbox(nested) => {
path = nested;
}
Path::Empty => break,
Path::Index {
index,
tag_id,
path: nested,
} => match Wrapped::opt_from_layout(&layout) {
} => {
result.push(PathInstruction {
index: *index,
tag_id: *tag_id,
});
path = nested;
}
}
}
result.reverse();
result
}
fn path_to_expr_help<'a>(
env: &mut Env<'a, '_>,
mut symbol: Symbol,
path: &Path,
mut layout: Layout<'a>,
) -> (Symbol, StoresVec<'a>, Layout<'a>) {
let mut stores = bumpalo::collections::Vec::new_in(env.arena);
let instructions = reverse_path(path);
let mut it = instructions.iter().peekable();
while let Some(PathInstruction { index, tag_id }) = it.next() {
match Wrapped::opt_from_layout(&layout) {
None => {
// this MUST be an index into a single-element (hence unwrapped) record
debug_assert_eq!(*index, 0);
debug_assert_eq!(*tag_id, 0);
debug_assert_eq!(**nested, Path::Empty);
debug_assert!(it.peek().is_none());
let field_layouts = vec![layout.clone()];
@ -981,10 +1009,10 @@ fn path_to_expr_help<'a>(
Some(wrapped) => {
let field_layouts = match &layout {
Layout::Union(layouts) | Layout::RecursiveUnion(layouts) => {
layouts[*tag_id as usize].to_vec()
layouts[*tag_id as usize]
}
Layout::Struct(layouts) => layouts.to_vec(),
other => vec![other.clone()],
Layout::Struct(layouts) => layouts,
other => env.arena.alloc([other.clone()]),
};
debug_assert!(*index < field_layouts.len() as u64);
@ -996,7 +1024,7 @@ fn path_to_expr_help<'a>(
let inner_expr = Expr::AccessAtIndex {
index: *index,
field_layouts: env.arena.alloc(field_layouts),
field_layouts,
structure: symbol,
wrapped,
};
@ -1005,9 +1033,7 @@ fn path_to_expr_help<'a>(
stores.push((symbol, inner_layout.clone(), inner_expr));
layout = inner_layout;
path = nested;
}
},
}
}
@ -1143,80 +1169,19 @@ fn test_to_equality<'a>(
}
}
// TODO procs and layout are currently unused, but potentially required
// for defining optional fields?
// if not, do remove
#[allow(clippy::too_many_arguments, clippy::needless_collect)]
fn decide_to_branching<'a>(
type Tests<'a> = std::vec::Vec<(
bumpalo::collections::Vec<'a, (Symbol, Layout<'a>, Expr<'a>)>,
Symbol,
Symbol,
Layout<'a>,
)>;
fn stores_and_condition<'a>(
env: &mut Env<'a, '_>,
procs: &mut Procs<'a>,
layout_cache: &mut LayoutCache<'a>,
cond_symbol: Symbol,
cond_layout: Layout<'a>,
ret_layout: Layout<'a>,
decider: Decider<'a, Choice<'a>>,
jumps: &Vec<(u64, Stmt<'a>)>,
) -> Stmt<'a> {
use Choice::*;
use Decider::*;
match decider {
Leaf(Jump(label)) => {
// we currently inline the jumps: does fewer jumps but produces a larger artifact
let (_, expr) = jumps
.iter()
.find(|(l, _)| l == &label)
.expect("jump not in list of jumps");
expr.clone()
}
Leaf(Inline(expr)) => expr,
Chain {
test_chain,
success,
failure,
} => {
// generate a switch based on the test chain
let pass_expr = decide_to_branching(
env,
procs,
layout_cache,
cond_symbol,
cond_layout.clone(),
ret_layout.clone(),
*success,
jumps,
);
let fail_expr = decide_to_branching(
env,
procs,
layout_cache,
cond_symbol,
cond_layout.clone(),
ret_layout.clone(),
*failure,
jumps,
);
let fail = &*env.arena.alloc(fail_expr);
let pass = &*env.arena.alloc(pass_expr);
let branching_symbol = env.unique_symbol();
let branching_layout = Layout::Builtin(Builtin::Int1);
let mut cond = Stmt::Cond {
cond_symbol,
cond_layout: cond_layout.clone(),
branching_symbol,
branching_layout,
pass,
fail,
ret_layout,
};
let true_symbol = env.unique_symbol();
cond_layout: &Layout<'a>,
test_chain: Vec<(Path, Test<'a>)>,
) -> (Tests<'a>, Option<(Symbol, JoinPointId, Stmt<'a>)>) {
let mut tests = Vec::with_capacity(test_chain.len());
let mut guard = None;
@ -1255,34 +1220,30 @@ fn decide_to_branching<'a>(
}
}
let mut current_symbol = branching_symbol;
// TODO There must be some way to remove this iterator/loop
let nr = (tests.len() as i64) - 1 + (guard.is_some() as i64);
(tests, guard)
}
fn compile_guard<'a>(
env: &mut Env<'a, '_>,
ret_layout: Layout<'a>,
id: JoinPointId,
stmt: &'a Stmt<'a>,
fail: &'a Stmt<'a>,
mut cond: Stmt<'a>,
) -> Stmt<'a> {
// the guard is the final thing that we check, so needs to be layered on first!
let test_symbol = env.unique_symbol();
let arena = env.arena;
let accum_symbols = std::iter::once(true_symbol)
.chain((0..nr).map(|_| env.unique_symbol()))
.rev()
.collect::<Vec<_>>();
let mut accum_it = accum_symbols.into_iter();
// the guard is the final thing that we check, so needs to be layered on first!
if let Some((_, id, stmt)) = guard {
let accum = accum_it.next().unwrap();
let test_symbol = env.unique_symbol();
let and_expr = Expr::RunLowLevel(LowLevel::And, arena.alloc([test_symbol, accum]));
// write to the branching symbol
cond = Stmt::Let(
current_symbol,
and_expr,
Layout::Builtin(Builtin::Int1),
arena.alloc(cond),
);
cond = Stmt::Cond {
cond_symbol: test_symbol,
cond_layout: Layout::Builtin(Builtin::Int1),
branching_symbol: test_symbol,
branching_layout: Layout::Builtin(Builtin::Int1),
pass: arena.alloc(cond),
fail,
ret_layout,
};
// calculate the guard value
let param = Param {
@ -1290,34 +1251,39 @@ fn decide_to_branching<'a>(
layout: Layout::Builtin(Builtin::Int1),
borrow: false,
};
cond = Stmt::Join {
Stmt::Join {
id,
parameters: arena.alloc([param]),
remainder: arena.alloc(stmt),
remainder: stmt,
continuation: arena.alloc(cond),
};
// load all the variables (the guard might need them);
current_symbol = accum;
}
}
for ((new_stores, lhs, rhs, _layout), accum) in tests.into_iter().rev().zip(accum_it) {
fn compile_test<'a>(
env: &mut Env<'a, '_>,
ret_layout: Layout<'a>,
stores: bumpalo::collections::Vec<'a, (Symbol, Layout<'a>, Expr<'a>)>,
lhs: Symbol,
rhs: Symbol,
fail: &'a Stmt<'a>,
mut cond: Stmt<'a>,
) -> Stmt<'a> {
// if test_symbol then cond else fail
let test_symbol = env.unique_symbol();
let test = Expr::RunLowLevel(
LowLevel::Eq,
bumpalo::vec![in arena; lhs, rhs].into_bump_slice(),
);
let arena = env.arena;
let and_expr = Expr::RunLowLevel(LowLevel::And, arena.alloc([test_symbol, accum]));
cond = Stmt::Cond {
cond_symbol: test_symbol,
cond_layout: Layout::Builtin(Builtin::Int1),
branching_symbol: test_symbol,
branching_layout: Layout::Builtin(Builtin::Int1),
pass: arena.alloc(cond),
fail,
ret_layout,
};
// write to the branching symbol
cond = Stmt::Let(
current_symbol,
and_expr,
Layout::Builtin(Builtin::Int1),
arena.alloc(cond),
);
let test = Expr::RunLowLevel(LowLevel::Eq, arena.alloc([lhs, rhs]));
// write to the test symbol
cond = Stmt::Let(
@ -1328,21 +1294,117 @@ fn decide_to_branching<'a>(
);
// stores are in top-to-bottom order, so we have to add them in reverse
for (symbol, layout, expr) in new_stores.into_iter().rev() {
for (symbol, layout, expr) in stores.into_iter().rev() {
cond = Stmt::Let(symbol, expr, layout, arena.alloc(cond));
}
current_symbol = accum;
cond
}
cond = Stmt::Let(
true_symbol,
Expr::Literal(Literal::Bool(true)),
Layout::Builtin(Builtin::Int1),
arena.alloc(cond),
fn compile_tests<'a>(
env: &mut Env<'a, '_>,
ret_layout: Layout<'a>,
tests: Tests<'a>,
opt_guard: Option<(Symbol, JoinPointId, Stmt<'a>)>,
fail: &'a Stmt<'a>,
mut cond: Stmt<'a>,
) -> Stmt<'a> {
let arena = env.arena;
// the guard is the final thing that we check, so needs to be layered on first!
if let Some((_, id, stmt)) = opt_guard {
cond = compile_guard(env, ret_layout.clone(), id, arena.alloc(stmt), fail, cond);
}
for (new_stores, lhs, rhs, _layout) in tests.into_iter().rev() {
cond = compile_test(env, ret_layout.clone(), new_stores, lhs, rhs, fail, cond);
}
cond
}
// TODO procs and layout are currently unused, but potentially required
// for defining optional fields?
// if not, do remove
#[allow(clippy::too_many_arguments, clippy::needless_collect)]
fn decide_to_branching<'a>(
env: &mut Env<'a, '_>,
procs: &mut Procs<'a>,
layout_cache: &mut LayoutCache<'a>,
cond_symbol: Symbol,
cond_layout: Layout<'a>,
ret_layout: Layout<'a>,
decider: Decider<'a, Choice<'a>>,
jumps: &Vec<(u64, Stmt<'a>)>,
) -> Stmt<'a> {
use Choice::*;
use Decider::*;
let arena = env.arena;
match decider {
Leaf(Jump(label)) => {
// we currently inline the jumps: does fewer jumps but produces a larger artifact
let (_, expr) = jumps
.iter()
.find(|(l, _)| l == &label)
.expect("jump not in list of jumps");
expr.clone()
}
Leaf(Inline(expr)) => expr,
Chain {
test_chain,
success,
failure,
} => {
// generate a (nested) if-then-else
let pass_expr = decide_to_branching(
env,
procs,
layout_cache,
cond_symbol,
cond_layout.clone(),
ret_layout.clone(),
*success,
jumps,
);
cond
let fail_expr = decide_to_branching(
env,
procs,
layout_cache,
cond_symbol,
cond_layout.clone(),
ret_layout.clone(),
*failure,
jumps,
);
let (tests, guard) = stores_and_condition(env, cond_symbol, &cond_layout, test_chain);
let number_of_tests = tests.len() as i64 + guard.is_some() as i64;
debug_assert!(number_of_tests > 0);
let fail = env.arena.alloc(fail_expr);
if number_of_tests == 1 {
// if there is just one test, compile to a simple if-then-else
compile_tests(env, ret_layout, tests, guard, fail, pass_expr)
} else {
// otherwise, we use a join point so the code for the `else` case
// is only generated once.
let fail_jp_id = JoinPointId(env.unique_symbol());
let jump = arena.alloc(Stmt::Jump(fail_jp_id, &[]));
let test_stmt = compile_tests(env, ret_layout, tests, guard, jump, pass_expr);
Stmt::Join {
id: fail_jp_id,
parameters: &[],
continuation: fail,
remainder: arena.alloc(test_stmt),
}
}
}
FanOut {
path,

View file

@ -1134,7 +1134,6 @@ impl<'a> Stmt<'a> {
alloc.intersperse(
vec![
remainder.to_doc(alloc),
alloc
.text("joinpoint ")
.append(join_point_to_doc(alloc, *id))
@ -1142,6 +1141,8 @@ impl<'a> Stmt<'a> {
.append(alloc.intersperse(it, alloc.space()))
.append(":"),
continuation.to_doc(alloc).indent(4),
alloc.text("in"),
remainder.to_doc(alloc),
],
alloc.hardline(),
)
@ -4397,6 +4398,7 @@ fn store_pattern<'a>(
field_layouts: arg_layouts.clone().into_bump_slice(),
structure: outer_symbol,
};
match argument {
Identifier(symbol) => {
// store immediately in the given symbol

View file

@ -435,6 +435,7 @@ impl<'a> Layout<'a> {
match self {
Layout::Builtin(Builtin::List(_, _)) => true,
Layout::RecursiveUnion(_) => true,
Layout::RecursivePointer => true,
_ => false,
}
}

View file

@ -116,11 +116,17 @@ mod test_mono {
let the_same = result == expected;
if !the_same {
println!("{}", result);
let expected_lines = expected.split("\n").collect::<Vec<&str>>();
let result_lines = result.split("\n").collect::<Vec<&str>>();
for line in &result_lines {
if !line.is_empty() {
println!(" {}", line);
} else {
println!("");
}
}
assert_eq!(expected_lines, result_lines);
assert_eq!(0, 1);
}
@ -183,15 +189,13 @@ mod test_mono {
indoc!(
r#"
procedure Test.0 ():
let Test.10 = 0i64;
let Test.11 = 3i64;
let Test.2 = Just Test.10 Test.11;
let Test.6 = true;
let Test.7 = 0i64;
let Test.8 = Index 0 Test.2;
let Test.9 = lowlevel Eq Test.7 Test.8;
let Test.5 = lowlevel And Test.9 Test.6;
if Test.5 then
let Test.8 = 0i64;
let Test.9 = 3i64;
let Test.2 = Just Test.8 Test.9;
let Test.5 = 0i64;
let Test.6 = Index 0 Test.2;
let Test.7 = lowlevel Eq Test.5 Test.6;
if Test.7 then
let Test.1 = Index 1 Test.2;
ret Test.1;
else
@ -310,29 +314,27 @@ mod test_mono {
indoc!(
r#"
procedure Num.32 (#Attr.2, #Attr.3):
let Test.19 = 0i64;
let Test.15 = lowlevel NotEq #Attr.3 Test.19;
if Test.15 then
let Test.17 = 1i64;
let Test.18 = lowlevel NumDivUnchecked #Attr.2 #Attr.3;
let Test.16 = Ok Test.17 Test.18;
ret Test.16;
let Test.17 = 0i64;
let Test.13 = lowlevel NotEq #Attr.3 Test.17;
if Test.13 then
let Test.15 = 1i64;
let Test.16 = lowlevel NumDivUnchecked #Attr.2 #Attr.3;
let Test.14 = Ok Test.15 Test.16;
ret Test.14;
else
let Test.13 = 0i64;
let Test.14 = Struct {};
let Test.12 = Err Test.13 Test.14;
ret Test.12;
let Test.11 = 0i64;
let Test.12 = Struct {};
let Test.10 = Err Test.11 Test.12;
ret Test.10;
procedure Test.0 ():
let Test.10 = 1000i64;
let Test.11 = 10i64;
let Test.2 = CallByName Num.32 Test.10 Test.11;
let Test.6 = true;
let Test.7 = 1i64;
let Test.8 = Index 0 Test.2;
let Test.9 = lowlevel Eq Test.7 Test.8;
let Test.5 = lowlevel And Test.9 Test.6;
if Test.5 then
let Test.8 = 1000i64;
let Test.9 = 10i64;
let Test.2 = CallByName Num.32 Test.8 Test.9;
let Test.5 = 1i64;
let Test.6 = Index 0 Test.2;
let Test.7 = lowlevel Eq Test.5 Test.6;
if Test.7 then
let Test.1 = Index 1 Test.2;
ret Test.1;
else
@ -386,15 +388,13 @@ mod test_mono {
ret Test.5;
procedure Test.0 ():
let Test.12 = 0i64;
let Test.13 = 41i64;
let Test.1 = Just Test.12 Test.13;
let Test.8 = true;
let Test.9 = 0i64;
let Test.10 = Index 0 Test.1;
let Test.11 = lowlevel Eq Test.9 Test.10;
let Test.7 = lowlevel And Test.11 Test.8;
if Test.7 then
let Test.10 = 0i64;
let Test.11 = 41i64;
let Test.1 = Just Test.10 Test.11;
let Test.7 = 0i64;
let Test.8 = Index 0 Test.1;
let Test.9 = lowlevel Eq Test.7 Test.8;
if Test.9 then
let Test.2 = Index 1 Test.1;
let Test.4 = 1i64;
let Test.3 = CallByName Num.14 Test.2 Test.4;
@ -442,20 +442,24 @@ mod test_mono {
r#"
procedure Test.1 (Test.2):
let Test.5 = 2i64;
let Test.11 = true;
let Test.12 = 2i64;
let Test.15 = lowlevel Eq Test.12 Test.5;
let Test.13 = lowlevel And Test.15 Test.11;
let Test.8 = false;
jump Test.7 Test.8;
joinpoint Test.7 Test.14:
let Test.10 = lowlevel And Test.14 Test.13;
if Test.10 then
joinpoint Test.11:
let Test.9 = 0i64;
ret Test.9;
in
let Test.10 = 2i64;
let Test.13 = lowlevel Eq Test.10 Test.5;
if Test.13 then
joinpoint Test.7 Test.12:
if Test.12 then
let Test.6 = 42i64;
ret Test.6;
else
let Test.9 = 0i64;
ret Test.9;
jump Test.11;
in
let Test.8 = false;
jump Test.7 Test.8;
else
jump Test.11;
procedure Test.0 ():
let Test.4 = Struct {};
@ -511,30 +515,33 @@ mod test_mono {
ret Test.6;
procedure Test.0 ():
let Test.17 = 0i64;
let Test.19 = 0i64;
let Test.21 = 0i64;
let Test.22 = 41i64;
let Test.20 = Just Test.21 Test.22;
let Test.2 = Just Test.19 Test.20;
let Test.10 = true;
let Test.20 = 41i64;
let Test.18 = Just Test.19 Test.20;
let Test.2 = Just Test.17 Test.18;
joinpoint Test.14:
let Test.8 = 1i64;
ret Test.8;
in
let Test.9 = Index 1 Test.2;
let Test.10 = 0i64;
let Test.11 = Index 0 Test.9;
let Test.16 = lowlevel Eq Test.10 Test.11;
if Test.16 then
let Test.12 = 0i64;
let Test.11 = Index 1 Test.2;
let Test.13 = Index 0 Test.11;
let Test.18 = lowlevel Eq Test.12 Test.13;
let Test.16 = lowlevel And Test.18 Test.10;
let Test.14 = 0i64;
let Test.15 = Index 0 Test.2;
let Test.17 = lowlevel Eq Test.14 Test.15;
let Test.9 = lowlevel And Test.17 Test.16;
if Test.9 then
let Test.13 = Index 0 Test.2;
let Test.15 = lowlevel Eq Test.12 Test.13;
if Test.15 then
let Test.7 = Index 1 Test.2;
let Test.3 = Index 1 Test.7;
let Test.5 = 1i64;
let Test.4 = CallByName Num.14 Test.3 Test.5;
ret Test.4;
else
let Test.8 = 1i64;
ret Test.8;
jump Test.14;
else
jump Test.14;
"#
),
)
@ -555,26 +562,29 @@ mod test_mono {
ret Test.6;
procedure Test.0 ():
let Test.16 = 2i64;
let Test.17 = 3i64;
let Test.3 = Struct {Test.16, Test.17};
let Test.8 = true;
let Test.10 = 4i64;
let Test.9 = Index 0 Test.3;
let Test.15 = lowlevel Eq Test.10 Test.9;
let Test.13 = lowlevel And Test.15 Test.8;
let Test.12 = 3i64;
let Test.11 = Index 1 Test.3;
let Test.14 = lowlevel Eq Test.12 Test.11;
let Test.7 = lowlevel And Test.14 Test.13;
if Test.7 then
let Test.4 = 9i64;
ret Test.4;
else
let Test.14 = 2i64;
let Test.15 = 3i64;
let Test.3 = Struct {Test.14, Test.15};
joinpoint Test.11:
let Test.1 = Index 0 Test.3;
let Test.2 = Index 1 Test.3;
let Test.5 = CallByName Num.14 Test.1 Test.2;
ret Test.5;
in
let Test.7 = Index 0 Test.3;
let Test.8 = 4i64;
let Test.13 = lowlevel Eq Test.8 Test.7;
if Test.13 then
let Test.9 = Index 1 Test.3;
let Test.10 = 3i64;
let Test.12 = lowlevel Eq Test.10 Test.9;
if Test.12 then
let Test.4 = 9i64;
ret Test.4;
else
jump Test.11;
else
jump Test.11;
"#
),
)
@ -698,6 +708,9 @@ mod test_mono {
r#"
procedure Test.1 (Test.4):
let Test.2 = 0u8;
joinpoint Test.8 Test.3:
ret Test.3;
in
switch Test.2:
case 1:
let Test.9 = 1i64;
@ -711,8 +724,6 @@ mod test_mono {
let Test.11 = 3i64;
jump Test.8 Test.11;
joinpoint Test.8 Test.3:
ret Test.3;
procedure Test.0 ():
let Test.6 = Struct {};
@ -798,21 +809,20 @@ mod test_mono {
indoc!(
r#"
procedure Test.1 (Test.4):
let Test.22 = 1i64;
let Test.23 = 2i64;
let Test.2 = Ok Test.22 Test.23;
let Test.18 = true;
let Test.19 = 1i64;
let Test.20 = Index 0 Test.2;
let Test.21 = lowlevel Eq Test.19 Test.20;
let Test.17 = lowlevel And Test.21 Test.18;
let Test.18 = 1i64;
let Test.19 = 2i64;
let Test.2 = Ok Test.18 Test.19;
joinpoint Test.8 Test.3:
ret Test.3;
in
let Test.15 = 1i64;
let Test.16 = Index 0 Test.2;
let Test.17 = lowlevel Eq Test.15 Test.16;
if Test.17 then
let Test.13 = true;
let Test.15 = 3i64;
let Test.14 = Index 1 Test.2;
let Test.16 = lowlevel Eq Test.15 Test.14;
let Test.12 = lowlevel And Test.16 Test.13;
if Test.12 then
let Test.12 = Index 1 Test.2;
let Test.13 = 3i64;
let Test.14 = lowlevel Eq Test.13 Test.12;
if Test.14 then
let Test.9 = 1i64;
jump Test.8 Test.9;
else
@ -821,8 +831,6 @@ mod test_mono {
else
let Test.11 = 3i64;
jump Test.8 Test.11;
joinpoint Test.8 Test.3:
ret Test.3;
procedure Test.0 ():
let Test.6 = Struct {};
@ -901,18 +909,17 @@ mod test_mono {
procedure Test.1 (Test.3):
let Test.6 = 10i64;
let Test.14 = true;
let Test.10 = 5i64;
let Test.9 = CallByName Bool.5 Test.6 Test.10;
jump Test.8 Test.9;
joinpoint Test.8 Test.15:
let Test.13 = lowlevel And Test.15 Test.14;
joinpoint Test.8 Test.13:
if Test.13 then
let Test.7 = 0i64;
ret Test.7;
else
let Test.12 = 42i64;
ret Test.12;
in
let Test.10 = 5i64;
let Test.9 = CallByName Bool.5 Test.6 Test.10;
jump Test.8 Test.9;
procedure Test.0 ():
let Test.5 = Struct {};
@ -977,11 +984,9 @@ mod test_mono {
r#"
procedure Test.0 ():
let Test.2 = 0i64;
let Test.6 = true;
let Test.7 = 1i64;
let Test.8 = lowlevel Eq Test.7 Test.2;
let Test.5 = lowlevel And Test.8 Test.6;
if Test.5 then
let Test.5 = 1i64;
let Test.6 = lowlevel Eq Test.5 Test.2;
if Test.6 then
let Test.3 = 12i64;
ret Test.3;
else
@ -1195,6 +1200,71 @@ mod test_mono {
),
indoc!(
r#"
procedure List.3 (#Attr.2, #Attr.3):
let Test.38 = lowlevel ListLen #Attr.2;
let Test.34 = lowlevel NumLt #Attr.3 Test.38;
if Test.34 then
let Test.36 = 1i64;
let Test.37 = lowlevel ListGetUnsafe #Attr.2 #Attr.3;
let Test.35 = Ok Test.36 Test.37;
ret Test.35;
else
let Test.32 = 0i64;
let Test.33 = Struct {};
let Test.31 = Err Test.32 Test.33;
ret Test.31;
procedure List.4 (#Attr.2, #Attr.3, #Attr.4):
let Test.14 = lowlevel ListLen #Attr.2;
let Test.12 = lowlevel NumLt #Attr.3 Test.14;
if Test.12 then
let Test.13 = lowlevel ListSet #Attr.2 #Attr.3 #Attr.4;
ret Test.13;
else
ret #Attr.2;
procedure Test.1 (Test.2):
let Test.39 = 0i64;
let Test.28 = CallByName List.3 Test.2 Test.39;
let Test.30 = 0i64;
let Test.29 = CallByName List.3 Test.2 Test.30;
let Test.7 = Struct {Test.28, Test.29};
joinpoint Test.25:
let Test.18 = Array [];
ret Test.18;
in
let Test.19 = Index 0 Test.7;
let Test.20 = 1i64;
let Test.21 = Index 0 Test.19;
let Test.27 = lowlevel Eq Test.20 Test.21;
if Test.27 then
let Test.22 = Index 1 Test.7;
let Test.23 = 1i64;
let Test.24 = Index 0 Test.22;
let Test.26 = lowlevel Eq Test.23 Test.24;
if Test.26 then
let Test.17 = Index 0 Test.7;
let Test.3 = Index 1 Test.17;
let Test.16 = Index 1 Test.7;
let Test.4 = Index 1 Test.16;
let Test.15 = 0i64;
let Test.9 = CallByName List.4 Test.2 Test.15 Test.4;
let Test.10 = 0i64;
let Test.8 = CallByName List.4 Test.9 Test.10 Test.3;
ret Test.8;
else
dec Test.2;
jump Test.25;
else
dec Test.2;
jump Test.25;
procedure Test.0 ():
let Test.40 = 1i64;
let Test.41 = 2i64;
let Test.6 = Array [Test.40, Test.41];
let Test.5 = CallByName Test.1 Test.6;
ret Test.5;
"#
),
)
@ -1226,18 +1296,18 @@ mod test_mono {
indoc!(
r#"
procedure List.3 (#Attr.2, #Attr.3):
let Test.40 = lowlevel ListLen #Attr.2;
let Test.36 = lowlevel NumLt #Attr.3 Test.40;
if Test.36 then
let Test.38 = 1i64;
let Test.39 = lowlevel ListGetUnsafe #Attr.2 #Attr.3;
let Test.37 = Ok Test.38 Test.39;
ret Test.37;
let Test.38 = lowlevel ListLen #Attr.2;
let Test.34 = lowlevel NumLt #Attr.3 Test.38;
if Test.34 then
let Test.36 = 1i64;
let Test.37 = lowlevel ListGetUnsafe #Attr.2 #Attr.3;
let Test.35 = Ok Test.36 Test.37;
ret Test.35;
else
let Test.34 = 0i64;
let Test.35 = Struct {};
let Test.33 = Err Test.34 Test.35;
ret Test.33;
let Test.32 = 0i64;
let Test.33 = Struct {};
let Test.31 = Err Test.32 Test.33;
ret Test.31;
procedure List.4 (#Attr.2, #Attr.3, #Attr.4):
let Test.14 = lowlevel ListLen #Attr.2;
@ -1249,23 +1319,25 @@ mod test_mono {
ret #Attr.2;
procedure Test.1 (Test.2):
let Test.41 = 0i64;
let Test.30 = CallByName List.3 Test.2 Test.41;
let Test.32 = 0i64;
let Test.31 = CallByName List.3 Test.2 Test.32;
let Test.7 = Struct {Test.30, Test.31};
let Test.20 = true;
let Test.22 = 1i64;
let Test.21 = Index 0 Test.7;
let Test.23 = Index 0 Test.21;
let Test.29 = lowlevel Eq Test.22 Test.23;
let Test.27 = lowlevel And Test.29 Test.20;
let Test.25 = 1i64;
let Test.24 = Index 1 Test.7;
let Test.26 = Index 0 Test.24;
let Test.28 = lowlevel Eq Test.25 Test.26;
let Test.19 = lowlevel And Test.28 Test.27;
if Test.19 then
let Test.39 = 0i64;
let Test.28 = CallByName List.3 Test.2 Test.39;
let Test.30 = 0i64;
let Test.29 = CallByName List.3 Test.2 Test.30;
let Test.7 = Struct {Test.28, Test.29};
joinpoint Test.25:
let Test.18 = Array [];
ret Test.18;
in
let Test.19 = Index 0 Test.7;
let Test.20 = 1i64;
let Test.21 = Index 0 Test.19;
let Test.27 = lowlevel Eq Test.20 Test.21;
if Test.27 then
let Test.22 = Index 1 Test.7;
let Test.23 = 1i64;
let Test.24 = Index 0 Test.22;
let Test.26 = lowlevel Eq Test.23 Test.24;
if Test.26 then
let Test.17 = Index 0 Test.7;
let Test.3 = Index 1 Test.17;
let Test.16 = Index 1 Test.7;
@ -1277,13 +1349,15 @@ mod test_mono {
ret Test.8;
else
dec Test.2;
let Test.18 = Array [];
ret Test.18;
jump Test.25;
else
dec Test.2;
jump Test.25;
procedure Test.0 ():
let Test.42 = 1i64;
let Test.43 = 2i64;
let Test.6 = Array [Test.42, Test.43];
let Test.40 = 1i64;
let Test.41 = 2i64;
let Test.6 = Array [Test.40, Test.41];
let Test.5 = CallByName Test.1 Test.6;
ret Test.5;
"#
@ -1430,19 +1504,18 @@ mod test_mono {
ret Test.12;
procedure Test.1 (Test.2, Test.3):
jump Test.7 Test.2 Test.3;
joinpoint Test.7 Test.2 Test.3:
let Test.16 = true;
let Test.17 = 0i64;
let Test.18 = lowlevel Eq Test.17 Test.2;
let Test.15 = lowlevel And Test.18 Test.16;
if Test.15 then
let Test.15 = 0i64;
let Test.16 = lowlevel Eq Test.15 Test.2;
if Test.16 then
ret Test.3;
else
let Test.13 = 1i64;
let Test.10 = CallByName Num.15 Test.2 Test.13;
let Test.11 = CallByName Num.16 Test.2 Test.3;
jump Test.7 Test.10 Test.11;
in
jump Test.7 Test.2 Test.3;
procedure Test.0 ():
let Test.5 = 10i64;
@ -1767,20 +1840,19 @@ mod test_mono {
indoc!(
r#"
procedure Test.0 ():
let Test.8 = 0i64;
let Test.10 = 0i64;
let Test.12 = 0i64;
let Test.14 = 0i64;
let Test.16 = 1i64;
let Test.15 = Z Test.16;
let Test.13 = S Test.14 Test.15;
let Test.14 = 1i64;
let Test.13 = Z Test.14;
let Test.11 = S Test.12 Test.13;
let Test.2 = S Test.10 Test.11;
let Test.6 = true;
let Test.7 = 1i64;
let Test.8 = Index 0 Test.2;
let Test.9 = lowlevel Eq Test.7 Test.8;
let Test.5 = lowlevel And Test.9 Test.6;
if Test.5 then
let Test.9 = S Test.10 Test.11;
let Test.2 = S Test.8 Test.9;
let Test.5 = 1i64;
let Test.6 = Index 0 Test.2;
dec Test.2;
let Test.7 = lowlevel Eq Test.5 Test.6;
if Test.7 then
let Test.3 = 0i64;
ret Test.3;
else
@ -1810,36 +1882,35 @@ mod test_mono {
indoc!(
r#"
procedure Test.0 ():
let Test.14 = 0i64;
let Test.16 = 0i64;
let Test.18 = 0i64;
let Test.20 = 0i64;
let Test.22 = 0i64;
let Test.24 = 1i64;
let Test.23 = Z Test.24;
let Test.21 = S Test.22 Test.23;
let Test.19 = S Test.20 Test.21;
let Test.2 = S Test.18 Test.19;
let Test.14 = true;
let Test.15 = 0i64;
let Test.16 = Index 0 Test.2;
let Test.17 = lowlevel Eq Test.15 Test.16;
let Test.13 = lowlevel And Test.17 Test.14;
let Test.20 = 1i64;
let Test.19 = Z Test.20;
let Test.17 = S Test.18 Test.19;
let Test.15 = S Test.16 Test.17;
let Test.2 = S Test.14 Test.15;
let Test.11 = 0i64;
let Test.12 = Index 0 Test.2;
let Test.13 = lowlevel Eq Test.11 Test.12;
if Test.13 then
let Test.8 = true;
let Test.10 = 0i64;
let Test.9 = Index 1 Test.2;
inc Test.9;
let Test.11 = Index 0 Test.9;
dec Test.9;
let Test.12 = lowlevel Eq Test.10 Test.11;
let Test.7 = lowlevel And Test.12 Test.8;
if Test.7 then
let Test.7 = Index 1 Test.2;
inc Test.7;
let Test.8 = 0i64;
let Test.9 = Index 0 Test.7;
dec Test.7;
let Test.10 = lowlevel Eq Test.8 Test.9;
if Test.10 then
let Test.4 = Index 1 Test.2;
dec Test.2;
let Test.3 = 1i64;
ret Test.3;
else
dec Test.2;
let Test.5 = 0i64;
ret Test.5;
else
dec Test.2;
let Test.6 = 0i64;
ret Test.6;
"#
@ -1872,12 +1943,10 @@ mod test_mono {
ret Test.13;
procedure Test.1 (Test.6):
let Test.19 = true;
let Test.21 = false;
let Test.20 = Index 0 Test.6;
let Test.22 = lowlevel Eq Test.21 Test.20;
let Test.18 = lowlevel And Test.22 Test.19;
if Test.18 then
let Test.18 = Index 0 Test.6;
let Test.19 = false;
let Test.20 = lowlevel Eq Test.19 Test.18;
if Test.20 then
let Test.8 = Index 1 Test.6;
ret Test.8;
else
@ -1885,11 +1954,9 @@ mod test_mono {
ret Test.10;
procedure Test.1 (Test.6):
let Test.32 = true;
let Test.34 = false;
let Test.33 = Index 0 Test.6;
let Test.35 = lowlevel Eq Test.34 Test.33;
let Test.31 = lowlevel And Test.35 Test.32;
let Test.29 = Index 0 Test.6;
let Test.30 = false;
let Test.31 = lowlevel Eq Test.30 Test.29;
if Test.31 then
let Test.8 = 3i64;
ret Test.8;
@ -1898,19 +1965,19 @@ mod test_mono {
ret Test.10;
procedure Test.0 ():
let Test.38 = true;
let Test.37 = Struct {Test.38};
let Test.5 = CallByName Test.1 Test.37;
let Test.36 = false;
let Test.28 = Struct {Test.36};
let Test.3 = CallByName Test.1 Test.28;
let Test.26 = true;
let Test.27 = 11i64;
let Test.25 = Struct {Test.26, Test.27};
let Test.4 = CallByName Test.1 Test.25;
let Test.23 = false;
let Test.24 = 7i64;
let Test.15 = Struct {Test.23, Test.24};
let Test.34 = true;
let Test.33 = Struct {Test.34};
let Test.5 = CallByName Test.1 Test.33;
let Test.32 = false;
let Test.26 = Struct {Test.32};
let Test.3 = CallByName Test.1 Test.26;
let Test.24 = true;
let Test.25 = 11i64;
let Test.23 = Struct {Test.24, Test.25};
let Test.4 = CallByName Test.1 Test.23;
let Test.21 = false;
let Test.22 = 7i64;
let Test.15 = Struct {Test.21, Test.22};
let Test.2 = CallByName Test.1 Test.15;
let Test.14 = CallByName Num.16 Test.2 Test.3;
let Test.12 = CallByName Num.16 Test.14 Test.4;
@ -1943,30 +2010,33 @@ mod test_mono {
ret Test.6;
procedure Test.0 ():
let Test.17 = 0i64;
let Test.19 = 0i64;
let Test.21 = 0i64;
let Test.22 = 41i64;
let Test.20 = Just Test.21 Test.22;
let Test.2 = Just Test.19 Test.20;
let Test.10 = true;
let Test.20 = 41i64;
let Test.18 = Just Test.19 Test.20;
let Test.2 = Just Test.17 Test.18;
joinpoint Test.14:
let Test.8 = 1i64;
ret Test.8;
in
let Test.9 = Index 1 Test.2;
let Test.10 = 0i64;
let Test.11 = Index 0 Test.9;
let Test.16 = lowlevel Eq Test.10 Test.11;
if Test.16 then
let Test.12 = 0i64;
let Test.11 = Index 1 Test.2;
let Test.13 = Index 0 Test.11;
let Test.18 = lowlevel Eq Test.12 Test.13;
let Test.16 = lowlevel And Test.18 Test.10;
let Test.14 = 0i64;
let Test.15 = Index 0 Test.2;
let Test.17 = lowlevel Eq Test.14 Test.15;
let Test.9 = lowlevel And Test.17 Test.16;
if Test.9 then
let Test.13 = Index 0 Test.2;
let Test.15 = lowlevel Eq Test.12 Test.13;
if Test.15 then
let Test.7 = Index 1 Test.2;
let Test.3 = Index 1 Test.7;
let Test.5 = 1i64;
let Test.4 = CallByName Num.14 Test.3 Test.5;
ret Test.4;
else
let Test.8 = 1i64;
ret Test.8;
jump Test.14;
else
jump Test.14;
"#
),
)
@ -2055,18 +2125,18 @@ mod test_mono {
indoc!(
r#"
procedure List.3 (#Attr.2, #Attr.3):
let Test.42 = lowlevel ListLen #Attr.2;
let Test.38 = lowlevel NumLt #Attr.3 Test.42;
if Test.38 then
let Test.40 = 1i64;
let Test.41 = lowlevel ListGetUnsafe #Attr.2 #Attr.3;
let Test.39 = Ok Test.40 Test.41;
ret Test.39;
let Test.40 = lowlevel ListLen #Attr.2;
let Test.36 = lowlevel NumLt #Attr.3 Test.40;
if Test.36 then
let Test.38 = 1i64;
let Test.39 = lowlevel ListGetUnsafe #Attr.2 #Attr.3;
let Test.37 = Ok Test.38 Test.39;
ret Test.37;
else
let Test.36 = 0i64;
let Test.37 = Struct {};
let Test.35 = Err Test.36 Test.37;
ret Test.35;
let Test.34 = 0i64;
let Test.35 = Struct {};
let Test.33 = Err Test.34 Test.35;
ret Test.33;
procedure List.4 (#Attr.2, #Attr.3, #Attr.4):
let Test.18 = lowlevel ListLen #Attr.2;
@ -2078,21 +2148,23 @@ mod test_mono {
ret #Attr.2;
procedure Test.1 (Test.2, Test.3, Test.4):
let Test.33 = CallByName List.3 Test.4 Test.2;
let Test.34 = CallByName List.3 Test.4 Test.3;
let Test.12 = Struct {Test.33, Test.34};
let Test.23 = true;
let Test.25 = 1i64;
let Test.24 = Index 0 Test.12;
let Test.26 = Index 0 Test.24;
let Test.32 = lowlevel Eq Test.25 Test.26;
let Test.30 = lowlevel And Test.32 Test.23;
let Test.28 = 1i64;
let Test.27 = Index 1 Test.12;
let Test.29 = Index 0 Test.27;
let Test.31 = lowlevel Eq Test.28 Test.29;
let Test.22 = lowlevel And Test.31 Test.30;
if Test.22 then
let Test.31 = CallByName List.3 Test.4 Test.2;
let Test.32 = CallByName List.3 Test.4 Test.3;
let Test.12 = Struct {Test.31, Test.32};
joinpoint Test.28:
let Test.21 = Array [];
ret Test.21;
in
let Test.22 = Index 0 Test.12;
let Test.23 = 1i64;
let Test.24 = Index 0 Test.22;
let Test.30 = lowlevel Eq Test.23 Test.24;
if Test.30 then
let Test.25 = Index 1 Test.12;
let Test.26 = 1i64;
let Test.27 = Index 0 Test.25;
let Test.29 = lowlevel Eq Test.26 Test.27;
if Test.29 then
let Test.20 = Index 0 Test.12;
let Test.5 = Index 1 Test.20;
let Test.19 = Index 1 Test.12;
@ -2102,14 +2174,16 @@ mod test_mono {
ret Test.13;
else
dec Test.4;
let Test.21 = Array [];
ret Test.21;
jump Test.28;
else
dec Test.4;
jump Test.28;
procedure Test.0 ():
let Test.9 = 0i64;
let Test.10 = 0i64;
let Test.43 = 1i64;
let Test.11 = Array [Test.43];
let Test.41 = 1i64;
let Test.11 = Array [Test.41];
let Test.8 = CallByName Test.1 Test.9 Test.10 Test.11;
ret Test.8;
"#

View file

@ -1047,7 +1047,7 @@ mod test_reporting {
r#"
TYPE MISMATCH
Something is off with the 1st branch of this `if` expression:
Something is off with the `then` branch of this `if` expression:
2 x = if True then 3.14 else 4
^^^^
@ -1084,12 +1084,14 @@ mod test_reporting {
r#"
TYPE MISMATCH
Something is off with the 1st branch of this `when` expression:
Something is off with the body of the `x` definition:
4 _ -> 3.14
^^^^
1 x : Int
2 x =
3> when True is
4> _ -> 3.14
The 1st branch is a float of type:
This `when`expression produces:
Float
@ -1123,15 +1125,15 @@ mod test_reporting {
1 x : Int -> Int
2 x = \_ -> 3.14
^^^^^^^^^^
^^^^
The body is an anonymous function of type:
The body is a float of type:
Int -> Float
Float
But the type annotation on `x` says it should be:
Int -> Int
Int
Tip: You can convert between Int and Float using functions like
`Num.toFloat` and `Num.round`.
@ -1664,7 +1666,7 @@ mod test_reporting {
r#"
TYPE MISMATCH
This `if` has an `else` branch with a different type from its `then` branch:
Something is off with the `else` branch of this `if` expression:
2 f = \x, y -> if True then x else y
^
@ -1673,12 +1675,10 @@ mod test_reporting {
b
but the `then` branch has the type:
But the type annotation on `f` says it should be:
a
I need all branches in an `if` to have the same type!
Tip: Your type annotation uses `b` and `a` as separate type variables.
Your code seems to be saying they are the same though. Maybe they
should be the same your type annotation? Maybe your code uses them in
@ -1707,15 +1707,15 @@ mod test_reporting {
1 f : Bool -> msg
2 f = \_ -> Foo
^^^^^^^^^
^^^
The body is an anonymous function of type:
This `Foo` global tag has the type:
Bool -> [ Foo ]a
[ Foo ]a
But the type annotation on `f` says it should be:
Bool -> msg
msg
Tip: The type annotation uses the type variable `msg` to say that this
definition can produce any type of value. But in the body I see that
@ -1830,18 +1830,19 @@ mod test_reporting {
Something is off with the body of the `f` definition:
1 f : Bool -> Int
2> f = \_ ->
3> ok = 3
4>
5> Ok
2 f = \_ ->
3 ok = 3
4
5 Ok
^^
The body is an anonymous function of type:
This `Ok` global tag has the type:
Bool -> [ Ok ]a
[ Ok ]a
But the type annotation on `f` says it should be:
Bool -> Int
Int
"#
),
)
@ -2141,15 +2142,15 @@ mod test_reporting {
1 f : [ A ] -> [ A, B ]
2 f = \a -> a
^^^^^^^
^
The body is an anonymous function of type:
This `a` value is a:
[ A ] -> [ A ]
[ A ]
But the type annotation on `f` says it should be:
[ A ] -> [ A, B ]
[ A, B ]
Tip: Looks like a closed tag union does not have the `B` tag.
@ -2179,15 +2180,15 @@ mod test_reporting {
1 f : [ A ] -> [ A, B, C ]
2 f = \a -> a
^^^^^^^
^
The body is an anonymous function of type:
This `a` value is a:
[ A ] -> [ A ]
[ A ]
But the type annotation on `f` says it should be:
[ A ] -> [ A, B, C ]
[ A, B, C ]
Tip: Looks like a closed tag union does not have the `C` and `B` tags.

View file

@ -3519,4 +3519,99 @@ mod solve_expr {
"Int, Int, List (Num a), Int, Num a -> [ Pair Int (List (Num a)) ]",
);
}
#[test]
fn rbtree_old_balance_simplified() {
infer_eq_without_problem(
indoc!(
r#"
app Test provides [ main ] imports []
Dict k : [ Node k (Dict k) (Dict k), Empty ]
balance : k, Dict k -> Dict k
balance = \key, left ->
Node key left Empty
main : Dict Int
main =
balance 0 Empty
"#
),
"Dict Int",
);
}
#[test]
fn rbtree_balance_simplified() {
infer_eq_without_problem(
indoc!(
r#"
app Test provides [ main ] imports []
Dict k : [ Node k (Dict k) (Dict k), Empty ]
node = \x,y,z -> Node x y z
balance : k, Dict k -> Dict k
balance = \key, left ->
node key left Empty
main : Dict Int
main =
balance 0 Empty
"#
),
"Dict Int",
);
}
#[test]
fn rbtree_balance() {
infer_eq_without_problem(
indoc!(
r#"
app Test provides [ main ] imports []
NodeColor : [ Red, Black ]
Dict k v : [ Node NodeColor k v (Dict k v) (Dict k v), Empty ]
balance : NodeColor, k, v, Dict k v, Dict k v -> Dict k v
balance = \color, key, value, left, right ->
when right is
Node Red rK rV rLeft rRight ->
when left is
Node Red lK lV lLeft lRight ->
Node
Red
key
value
(Node Black lK lV lLeft lRight)
(Node Black rK rV rLeft rRight)
_ ->
Node color rK rV (Node Red key value left rLeft) rRight
_ ->
when left is
Node Red lK lV (Node Red llK llV llLeft llRight) lRight ->
Node
Red
lK
lV
(Node Black llK llV llLeft llRight)
(Node Black key value lRight right)
_ ->
Node color key value left right
main : Dict Int Int
main =
balance Red 0 0 Empty Empty
"#
),
"Dict Int Int",
);
}
}

View file

@ -180,7 +180,8 @@ fn unify_alias(
// Alias wins
merge(subs, &ctx, Alias(symbol, args.to_owned(), real_var))
}
RecursionVar { .. } | RigidVar(_) => unify_pool(subs, pool, real_var, ctx.second),
RecursionVar { structure, .. } => unify_pool(subs, pool, real_var, *structure),
RigidVar(_) => unify_pool(subs, pool, real_var, ctx.second),
Alias(other_symbol, other_args, other_real_var) => {
if symbol == *other_symbol {
if args.len() == other_args.len() {
@ -240,7 +241,8 @@ fn unify_structure(
problems
}
FlatType::RecursiveTagUnion(_, _, _) => {
FlatType::RecursiveTagUnion(rec, _, _) => {
debug_assert!(is_recursion_var(subs, *rec));
let structure_rank = subs.get(*structure).rank;
let self_rank = subs.get(ctx.first).rank;
let other_rank = subs.get(ctx.second).rank;
@ -593,7 +595,7 @@ fn unify_tag_union_not_recursive_recursive(
tag_problems
} else {
let flat_type = FlatType::RecursiveTagUnion(recursion_var, unique_tags2, rec2.ext);
let flat_type = FlatType::TagUnion(unique_tags2, rec2.ext);
let sub_record = fresh(subs, pool, ctx, Structure(flat_type));
let ext_problems = unify_pool(subs, pool, rec1.ext, sub_record);
@ -616,7 +618,7 @@ fn unify_tag_union_not_recursive_recursive(
tag_problems
}
} else if unique_tags2.is_empty() {
let flat_type = FlatType::RecursiveTagUnion(recursion_var, unique_tags1, rec1.ext);
let flat_type = FlatType::TagUnion(unique_tags1, rec1.ext);
let sub_record = fresh(subs, pool, ctx, Structure(flat_type));
let ext_problems = unify_pool(subs, pool, sub_record, rec2.ext);
@ -641,8 +643,8 @@ fn unify_tag_union_not_recursive_recursive(
let other_tags = union(unique_tags1.clone(), &unique_tags2);
let ext = fresh(subs, pool, ctx, Content::FlexVar(None));
let flat_type1 = FlatType::RecursiveTagUnion(recursion_var, unique_tags1, ext);
let flat_type2 = FlatType::RecursiveTagUnion(recursion_var, unique_tags2, ext);
let flat_type1 = FlatType::TagUnion(unique_tags1, ext);
let flat_type2 = FlatType::TagUnion(unique_tags2, ext);
let sub1 = fresh(subs, pool, ctx, Structure(flat_type1));
let sub2 = fresh(subs, pool, ctx, Structure(flat_type2));
@ -855,6 +857,7 @@ fn unify_shared_tags(
new_tags.extend(fields.into_iter());
let flat_type = if let Some(rec) = recursion_var {
debug_assert!(is_recursion_var(subs, rec));
FlatType::RecursiveTagUnion(rec, new_tags, new_ext_var)
} else {
FlatType::TagUnion(new_tags, new_ext_var)
@ -924,6 +927,7 @@ fn unify_flat_type(
}
(RecursiveTagUnion(recursion_var, tags1, ext1), TagUnion(tags2, ext2)) => {
debug_assert!(is_recursion_var(subs, *recursion_var));
// this never happens in type-correct programs, but may happen if there is a type error
let union1 = gather_tags(subs, tags1.clone(), *ext1);
let union2 = gather_tags(subs, tags2.clone(), *ext2);
@ -939,6 +943,7 @@ fn unify_flat_type(
}
(TagUnion(tags1, ext1), RecursiveTagUnion(recursion_var, tags2, ext2)) => {
debug_assert!(is_recursion_var(subs, *recursion_var));
let union1 = gather_tags(subs, tags1.clone(), *ext1);
let union2 = gather_tags(subs, tags2.clone(), *ext2);
@ -946,6 +951,8 @@ fn unify_flat_type(
}
(RecursiveTagUnion(rec1, tags1, ext1), RecursiveTagUnion(rec2, tags2, ext2)) => {
debug_assert!(is_recursion_var(subs, *rec1));
debug_assert!(is_recursion_var(subs, *rec2));
let union1 = gather_tags(subs, tags1.clone(), *ext1);
let union2 = gather_tags(subs, tags2.clone(), *ext2);
@ -1153,13 +1160,16 @@ fn unify_recursion(
// unify the structure variable with this Structure
unify_pool(subs, pool, structure, ctx.second)
}
RigidVar(_) => mismatch!("RecursionVar {:?} with rigid {:?}", ctx.first, &other),
FlexVar(_) | RigidVar(_) => {
// TODO special-case boolean here
// In all other cases, if left is flex, defer to right.
// (This includes using right's name if both are flex and named.)
merge(subs, ctx, other.clone())
}
FlexVar(_) => merge(
subs,
ctx,
RecursionVar {
structure,
opt_name: opt_name.clone(),
},
),
Alias(_, _, actual) => {
// look at the type the alias stands for
@ -1227,3 +1237,7 @@ fn gather_tags(
_ => TagUnionStructure { tags, ext: var },
}
}
fn is_recursion_var(subs: &Subs, var: Variable) -> bool {
matches!(subs.get_without_compacting(var).content, Content::RecursionVar { .. })
}