fix recursion issue in type inference

This commit is contained in:
Folkert 2020-09-07 16:35:49 +02:00
parent 0a034c474a
commit 4522fe14fc
4 changed files with 86 additions and 31 deletions

View file

@ -555,9 +555,7 @@ mod gen_primitives {
Nil -> 0 Nil -> 0
Cons _ rest -> 1 + length rest Cons _ rest -> 1 + length rest
wrapper = { list: nil } length nil + length nil
length wrapper.list
"# "#
), ),
0, 0,
@ -611,8 +609,7 @@ mod gen_primitives {
Cons _ rest -> 1 + length rest Cons _ rest -> 1 + length rest
# TODO actually calculate twice length one + length one
2 * length one
"# "#
), ),
2, 2,

View file

@ -625,7 +625,7 @@ mod test_mono {
fn when_joinpoint() { fn when_joinpoint() {
compiles_to_ir( compiles_to_ir(
r#" r#"
main = \{} -> main = \{} ->
x : [ Red, White, Blue ] x : [ Red, White, Blue ]
x = Blue x = Blue
@ -647,15 +647,15 @@ mod test_mono {
case 1: case 1:
let Test.9 = 1i64; let Test.9 = 1i64;
jump Test.8 Test.9; jump Test.8 Test.9;
case 2: case 2:
let Test.10 = 2i64; let Test.10 = 2i64;
jump Test.8 Test.10; jump Test.8 Test.10;
default: default:
let Test.11 = 3i64; let Test.11 = 3i64;
jump Test.8 Test.11; jump Test.8 Test.11;
joinpoint Test.8 Test.3: joinpoint Test.8 Test.3:
ret Test.3; ret Test.3;
@ -724,7 +724,7 @@ mod test_mono {
fn when_on_result() { fn when_on_result() {
compiles_to_ir( compiles_to_ir(
r#" r#"
main = \{} -> main = \{} ->
x : Result Int Int x : Result Int Int
x = Ok 2 x = Ok 2
@ -824,7 +824,7 @@ mod test_mono {
compiles_to_ir( compiles_to_ir(
indoc!( indoc!(
r#" r#"
main = \{} -> main = \{} ->
when 10 is when 10 is
x if x == 5 -> 0 x if x == 5 -> 0
_ -> 42 _ -> 42
@ -1317,10 +1317,10 @@ mod test_mono {
r#" r#"
factorial = \n, accum -> factorial = \n, accum ->
when n is when n is
0 -> 0 ->
accum accum
_ -> _ ->
factorial (n - 1) (n * accum) factorial (n - 1) (n * accum)
factorial 10 1 factorial 10 1
@ -1368,11 +1368,11 @@ mod test_mono {
isNil : ConsList a -> Bool isNil : ConsList a -> Bool
isNil = \list -> isNil = \list ->
when list is when list is
Nil -> True Nil -> True
Cons _ _ -> False Cons _ _ -> False
isNil (Cons 0x2 Nil) isNil (Cons 0x2 Nil)
"#, "#,
indoc!( indoc!(
r#" r#"
@ -1411,12 +1411,12 @@ mod test_mono {
hasNone : ConsList (Maybe a) -> Bool hasNone : ConsList (Maybe a) -> Bool
hasNone = \list -> hasNone = \list ->
when list is when list is
Nil -> False Nil -> False
Cons Nothing _ -> True Cons Nothing _ -> True
Cons (Just _) xs -> hasNone xs Cons (Just _) xs -> hasNone xs
hasNone (Cons (Just 3) Nil) hasNone (Cons (Just 3) Nil)
"#, "#,
indoc!( indoc!(
r#" r#"
@ -1475,7 +1475,7 @@ mod test_mono {
fn fst() { fn fst() {
compiles_to_ir( compiles_to_ir(
r#" r#"
fst = \x, y -> x fst = \x, y -> x
fst [1,2,3] [3,2,1] fst [1,2,3] [3,2,1]
"#, "#,
@ -1512,7 +1512,7 @@ mod test_mono {
add : List Int -> List Int add : List Int -> List Int
add = \y -> List.set y 0 0 add = \y -> List.set y 0 0
List.len (add x) + List.len x List.len (add x) + List.len x
"# "#
), ),
@ -1563,7 +1563,7 @@ mod test_mono {
compiles_to_ir( compiles_to_ir(
indoc!( indoc!(
r#" r#"
main = \{} -> main = \{} ->
List.get [1,2,3] 0 List.get [1,2,3] 0
main {} main {}
@ -1736,9 +1736,9 @@ mod test_mono {
{ x: Blue, y ? 3 } -> y { x: Blue, y ? 3 } -> y
{ x: Red, y ? 5 } -> y { x: Red, y ? 5 } -> y
a = f { x: Blue, y: 7 } a = f { x: Blue, y: 7 }
b = f { x: Blue } b = f { x: Blue }
c = f { x: Red, y: 11 } c = f { x: Red, y: 11 }
d = f { x: Red } d = f { x: Red }
a * b * c * d a * b * c * d
@ -1848,4 +1848,58 @@ mod test_mono {
), ),
) )
} }
#[test]
fn linked_list_length_twice() {
compiles_to_ir(
indoc!(
r#"
LinkedList a : [ Nil, Cons a (LinkedList a) ]
nil : LinkedList Int
nil = Nil
length : LinkedList a -> Int
length = \list ->
when list is
Nil -> 0
Cons _ rest -> 1 + length rest
length nil + length nil
"#
),
indoc!(
r#"
procedure Num.14 (#Attr.2, #Attr.3):
let Test.14 = lowlevel NumAdd #Attr.2 #Attr.3;
ret Test.14;
procedure Test.2 (Test.4):
let Test.16 = true;
let Test.17 = 1i64;
let Test.18 = Index 0 Test.4;
let Test.19 = lowlevel Eq Test.17 Test.18;
let Test.15 = lowlevel And Test.19 Test.16;
if Test.15 then
dec Test.4;
let Test.10 = 0i64;
ret Test.10;
else
let Test.5 = Index 2 Test.4;
dec Test.4;
let Test.12 = 1i64;
let Test.13 = CallByName Test.2 Test.5;
let Test.11 = CallByName Num.14 Test.12 Test.13;
ret Test.11;
let Test.9 = 1i64;
let Test.1 = Nil Test.9;
let Test.7 = CallByName Test.2 Test.1;
let Test.8 = CallByName Test.2 Test.1;
let Test.6 = CallByName Num.14 Test.7 Test.8;
ret Test.6;
"#
),
)
}
} }

View file

@ -2005,7 +2005,7 @@ mod solve_uniq_expr {
toAs toAs
"# "#
), ),
"Attr Shared (Attr Shared (Attr a q -> Attr b p), Attr (* | a | b) (ListA (Attr b p) (Attr a q)) -> Attr * (ConsList (Attr b p)))" "Attr Shared (Attr Shared (Attr a q -> Attr b p), Attr (c | a | b) (ListA (Attr b p) (Attr a q)) -> Attr * (ConsList (Attr b p)))"
); );
} }
@ -2039,7 +2039,7 @@ mod solve_uniq_expr {
toAs toAs
"# "#
), ),
"Attr Shared (Attr Shared (Attr a q -> Attr b p), Attr (* | a | b) (ListA (Attr b p) (Attr a q)) -> Attr * (ConsList (Attr b p)))" "Attr Shared (Attr Shared (Attr a q -> Attr b p), Attr (c | a | b) (ListA (Attr b p) (Attr a q)) -> Attr * (ConsList (Attr b p)))"
); );
} }
@ -2789,7 +2789,7 @@ mod solve_uniq_expr {
cheapestOpen cheapestOpen
"# "#
), ),
"Attr * (Attr * (Attr Shared position -> Attr * Float), Attr (* | * | * | *) (Model (Attr Shared position)) -> Attr * (Result (Attr Shared position) (Attr * [ KeyNotFound ]*)))" "Attr * (Attr * (Attr Shared position -> Attr * Float), Attr (* | * | a | b) (Model (Attr Shared position)) -> Attr * (Result (Attr Shared position) (Attr * [ KeyNotFound ]*)))"
) )
}); });
} }

View file

@ -581,8 +581,8 @@ fn unify_shared_tags(
if let Some(rvar) = recursion_var { if let Some(rvar) = recursion_var {
match attr_wrapped { match attr_wrapped {
None => { None => {
if expected == rvar { if subs.equivalent(expected, rvar) {
if actual == rvar { if subs.equivalent(actual, rvar) {
problems.extend(unify_pool(subs, pool, expected, actual)); problems.extend(unify_pool(subs, pool, expected, actual));
} else { } else {
problems.extend(unify_pool(subs, pool, actual, ctx.second)); problems.extend(unify_pool(subs, pool, actual, ctx.second));
@ -610,8 +610,8 @@ fn unify_shared_tags(
} }
} }
Some((_expected_uvar, inner_expected, _actual_uvar, inner_actual)) => { Some((_expected_uvar, inner_expected, _actual_uvar, inner_actual)) => {
if inner_expected == rvar { if subs.equivalent(inner_expected, rvar) {
if inner_actual == rvar { if subs.equivalent(inner_actual, rvar) {
problems.extend(unify_pool(subs, pool, actual, expected)); problems.extend(unify_pool(subs, pool, actual, expected));
} else { } else {
problems.extend(unify_pool(subs, pool, inner_actual, ctx.second)); problems.extend(unify_pool(subs, pool, inner_actual, ctx.second));
@ -747,7 +747,7 @@ fn unify_flat_type(
ctx, ctx,
union1, union1,
union2, union2,
(None, Some(*recursion_var)), (Some(*recursion_var), None),
) )
} }
@ -774,7 +774,11 @@ fn unify_flat_type(
(Boolean(b1), Boolean(b2)) => { (Boolean(b1), Boolean(b2)) => {
use Bool::*; use Bool::*;
match (b1, b2) {
let b1 = b1.simplify(subs);
let b2 = b2.simplify(subs);
match (&b1, &b2) {
(Shared, Shared) => merge(subs, ctx, Structure(left.clone())), (Shared, Shared) => merge(subs, ctx, Structure(left.clone())),
(Shared, Container(cvar, mvars)) => { (Shared, Container(cvar, mvars)) => {
let mut outcome = vec![]; let mut outcome = vec![];