diff --git a/src/uniqueness/mod.rs b/src/uniqueness/mod.rs index e6af814bef..e4be9ccd43 100644 --- a/src/uniqueness/mod.rs +++ b/src/uniqueness/mod.rs @@ -459,7 +459,9 @@ pub fn constrain_expr( expected, ) } - Closure(fn_var, _symbol, _recursion, args, boxed) => { + Closure(fn_var, _symbol, recursion, args, boxed) => { + use crate::can::expr::Recursive; + let (loc_body_expr, ret_var) = &**boxed; let mut state = PatternState { headers: SendMap::default(), @@ -485,11 +487,18 @@ pub fn constrain_expr( vars.push(*pattern_var); } - let fn_uniq_var = var_store.fresh(); - vars.push(fn_uniq_var); + let fn_uniq_type; + if let Recursive::NotRecursive = recursion { + let fn_uniq_var = var_store.fresh(); + vars.push(fn_uniq_var); + fn_uniq_type = Bool::Variable(fn_uniq_var); + } else { + // recursive definitions MUST be Shared + fn_uniq_type = Bool::Zero + } let fn_type = constrain::attr_type( - Bool::Variable(fn_uniq_var), + fn_uniq_type, Type::Function(pattern_types, Box::new(ret_type.clone())), ); let body_type = Expected::NoExpectation(ret_type); diff --git a/tests/test_infer.rs b/tests/test_infer.rs index e54348564f..a7d60c0a4b 100644 --- a/tests/test_infer.rs +++ b/tests/test_infer.rs @@ -1246,23 +1246,22 @@ mod test_infer { ); } + // TODO add more realistic function when able #[test] fn integer_sum() { - with_larger_debug_stack(|| { - infer_eq_without_problem( - indoc!( - r#" - f = \n -> + infer_eq_without_problem( + indoc!( + r#" + f = \n -> when n is 0 -> 0 _ -> f n f "# - ), - "Int -> Int", - ); - }); + ), + "Int -> Int", + ); } // currently fails, the rank of Cons's ext_var is 3, where 2 is the highest pool diff --git a/tests/test_uniqueness_infer.rs b/tests/test_uniqueness_infer.rs index 249fae6496..bfc6f762b9 100644 --- a/tests/test_uniqueness_infer.rs +++ b/tests/test_uniqueness_infer.rs @@ -1231,4 +1231,41 @@ mod test_infer_uniq { "Attr.Attr * Int", ); } + + // TODO add more realistic recursive example when able + #[test] + fn factorial_is_shared() { + infer_eq_without_problem( + indoc!( + r#" + factorial = \n -> + when n is + 0 -> 1 + 1 -> 1 + m -> factorial m + + factorial + "# + ), + "Attr.Attr Attr.Shared (Attr.Attr * Int -> Attr.Attr * Int)", + ); + } + + // TODO add more realistic recursive example when able + #[test] + fn factorial_without_recursive_case_can_be_unique() { + infer_eq_without_problem( + indoc!( + r#" + factorial = \n -> + when n is + 0 -> 1 + _ -> 1 + + factorial + "# + ), + "Attr.Attr * (Attr.Attr * Int -> Attr.Attr * Int)", + ); + } }