diff --git a/compiler/gen/tests/gen_primitives.rs b/compiler/gen/tests/gen_primitives.rs index d461a42a36..9b5e07f7a3 100644 --- a/compiler/gen/tests/gen_primitives.rs +++ b/compiler/gen/tests/gen_primitives.rs @@ -555,9 +555,7 @@ mod gen_primitives { Nil -> 0 Cons _ rest -> 1 + length rest - wrapper = { list: nil } - - length wrapper.list + length nil + length nil "# ), 0, @@ -611,8 +609,7 @@ mod gen_primitives { Cons _ rest -> 1 + length rest - # TODO actually calculate twice - 2 * length one + length one + length one "# ), 2, diff --git a/compiler/mono/tests/test_mono.rs b/compiler/mono/tests/test_mono.rs index 602c49eb05..dc6187b273 100644 --- a/compiler/mono/tests/test_mono.rs +++ b/compiler/mono/tests/test_mono.rs @@ -625,7 +625,7 @@ mod test_mono { fn when_joinpoint() { compiles_to_ir( r#" - main = \{} -> + main = \{} -> x : [ Red, White, Blue ] x = Blue @@ -647,15 +647,15 @@ mod test_mono { case 1: let Test.9 = 1i64; jump Test.8 Test.9; - + case 2: let Test.10 = 2i64; jump Test.8 Test.10; - + default: let Test.11 = 3i64; jump Test.8 Test.11; - + joinpoint Test.8 Test.3: ret Test.3; @@ -724,7 +724,7 @@ mod test_mono { fn when_on_result() { compiles_to_ir( r#" - main = \{} -> + main = \{} -> x : Result Int Int x = Ok 2 @@ -824,7 +824,7 @@ mod test_mono { compiles_to_ir( indoc!( r#" - main = \{} -> + main = \{} -> when 10 is x if x == 5 -> 0 _ -> 42 @@ -1317,10 +1317,10 @@ mod test_mono { r#" factorial = \n, accum -> when n is - 0 -> + 0 -> accum - _ -> + _ -> factorial (n - 1) (n * accum) factorial 10 1 @@ -1368,11 +1368,11 @@ mod test_mono { isNil : ConsList a -> Bool isNil = \list -> - when list is + when list is Nil -> True Cons _ _ -> False - isNil (Cons 0x2 Nil) + isNil (Cons 0x2 Nil) "#, indoc!( r#" @@ -1411,12 +1411,12 @@ mod test_mono { hasNone : ConsList (Maybe a) -> Bool hasNone = \list -> - when list is + when list is Nil -> False Cons Nothing _ -> True Cons (Just _) xs -> hasNone xs - hasNone (Cons (Just 3) Nil) + hasNone (Cons (Just 3) Nil) "#, indoc!( r#" @@ -1475,7 +1475,7 @@ mod test_mono { fn fst() { compiles_to_ir( r#" - fst = \x, y -> x + fst = \x, y -> x fst [1,2,3] [3,2,1] "#, @@ -1512,7 +1512,7 @@ mod test_mono { add : List Int -> List Int add = \y -> List.set y 0 0 - + List.len (add x) + List.len x "# ), @@ -1563,7 +1563,7 @@ mod test_mono { compiles_to_ir( indoc!( r#" - main = \{} -> + main = \{} -> List.get [1,2,3] 0 main {} @@ -1736,9 +1736,9 @@ mod test_mono { { x: Blue, y ? 3 } -> y { x: Red, y ? 5 } -> y - a = f { x: Blue, y: 7 } + a = f { x: Blue, y: 7 } b = f { x: Blue } - c = f { x: Red, y: 11 } + c = f { x: Red, y: 11 } d = f { x: Red } a * b * c * d @@ -1848,4 +1848,58 @@ mod test_mono { ), ) } + + #[test] + fn linked_list_length_twice() { + compiles_to_ir( + indoc!( + r#" + LinkedList a : [ Nil, Cons a (LinkedList a) ] + + nil : LinkedList Int + nil = Nil + + length : LinkedList a -> Int + length = \list -> + when list is + Nil -> 0 + Cons _ rest -> 1 + length rest + + length nil + length nil + "# + ), + indoc!( + r#" + procedure Num.14 (#Attr.2, #Attr.3): + let Test.14 = lowlevel NumAdd #Attr.2 #Attr.3; + ret Test.14; + + procedure Test.2 (Test.4): + let Test.16 = true; + let Test.17 = 1i64; + let Test.18 = Index 0 Test.4; + let Test.19 = lowlevel Eq Test.17 Test.18; + let Test.15 = lowlevel And Test.19 Test.16; + if Test.15 then + dec Test.4; + let Test.10 = 0i64; + ret Test.10; + else + let Test.5 = Index 2 Test.4; + dec Test.4; + let Test.12 = 1i64; + let Test.13 = CallByName Test.2 Test.5; + let Test.11 = CallByName Num.14 Test.12 Test.13; + ret Test.11; + + let Test.9 = 1i64; + let Test.1 = Nil Test.9; + let Test.7 = CallByName Test.2 Test.1; + let Test.8 = CallByName Test.2 Test.1; + let Test.6 = CallByName Num.14 Test.7 Test.8; + ret Test.6; + "# + ), + ) + } } diff --git a/compiler/solve/tests/solve_uniq_expr.rs b/compiler/solve/tests/solve_uniq_expr.rs index d50c1f808a..06ca77ed59 100644 --- a/compiler/solve/tests/solve_uniq_expr.rs +++ b/compiler/solve/tests/solve_uniq_expr.rs @@ -2005,7 +2005,7 @@ mod solve_uniq_expr { toAs "# ), - "Attr Shared (Attr Shared (Attr a q -> Attr b p), Attr (* | a | b) (ListA (Attr b p) (Attr a q)) -> Attr * (ConsList (Attr b p)))" + "Attr Shared (Attr Shared (Attr a q -> Attr b p), Attr (c | a | b) (ListA (Attr b p) (Attr a q)) -> Attr * (ConsList (Attr b p)))" ); } @@ -2039,7 +2039,7 @@ mod solve_uniq_expr { toAs "# ), - "Attr Shared (Attr Shared (Attr a q -> Attr b p), Attr (* | a | b) (ListA (Attr b p) (Attr a q)) -> Attr * (ConsList (Attr b p)))" + "Attr Shared (Attr Shared (Attr a q -> Attr b p), Attr (c | a | b) (ListA (Attr b p) (Attr a q)) -> Attr * (ConsList (Attr b p)))" ); } @@ -2789,7 +2789,7 @@ mod solve_uniq_expr { cheapestOpen "# ), - "Attr * (Attr * (Attr Shared position -> Attr * Float), Attr (* | * | * | *) (Model (Attr Shared position)) -> Attr * (Result (Attr Shared position) (Attr * [ KeyNotFound ]*)))" + "Attr * (Attr * (Attr Shared position -> Attr * Float), Attr (* | * | a | b) (Model (Attr Shared position)) -> Attr * (Result (Attr Shared position) (Attr * [ KeyNotFound ]*)))" ) }); } diff --git a/compiler/unify/src/unify.rs b/compiler/unify/src/unify.rs index f62deead83..b42aae3b2a 100644 --- a/compiler/unify/src/unify.rs +++ b/compiler/unify/src/unify.rs @@ -581,8 +581,8 @@ fn unify_shared_tags( if let Some(rvar) = recursion_var { match attr_wrapped { None => { - if expected == rvar { - if actual == rvar { + if subs.equivalent(expected, rvar) { + if subs.equivalent(actual, rvar) { problems.extend(unify_pool(subs, pool, expected, actual)); } else { problems.extend(unify_pool(subs, pool, actual, ctx.second)); @@ -610,8 +610,8 @@ fn unify_shared_tags( } } Some((_expected_uvar, inner_expected, _actual_uvar, inner_actual)) => { - if inner_expected == rvar { - if inner_actual == rvar { + if subs.equivalent(inner_expected, rvar) { + if subs.equivalent(inner_actual, rvar) { problems.extend(unify_pool(subs, pool, actual, expected)); } else { problems.extend(unify_pool(subs, pool, inner_actual, ctx.second)); @@ -747,7 +747,7 @@ fn unify_flat_type( ctx, union1, union2, - (None, Some(*recursion_var)), + (Some(*recursion_var), None), ) } @@ -774,7 +774,11 @@ fn unify_flat_type( (Boolean(b1), Boolean(b2)) => { use Bool::*; - match (b1, b2) { + + let b1 = b1.simplify(subs); + let b2 = b2.simplify(subs); + + match (&b1, &b2) { (Shared, Shared) => merge(subs, ctx, Structure(left.clone())), (Shared, Container(cvar, mvars)) => { let mut outcome = vec![];