From 2856a382364404e8a4cbb185ea9c7e8dcda475e6 Mon Sep 17 00:00:00 2001 From: Ayaz Hafiz Date: Mon, 18 Apr 2022 11:03:25 -0400 Subject: [PATCH] Solve all when branch pattern constraints before solving their bodies Closes #2886 --- compiler/constrain/src/expr.rs | 214 +++++++++++++++-------------- compiler/solve/tests/solve_expr.rs | 17 +++ 2 files changed, 127 insertions(+), 104 deletions(-) diff --git a/compiler/constrain/src/expr.rs b/compiler/constrain/src/expr.rs index 9fc45351dc..6eff9c376d 100644 --- a/compiler/constrain/src/expr.rs +++ b/compiler/constrain/src/expr.rs @@ -596,109 +596,90 @@ pub fn constrain_expr( NoExpectation(cond_type.clone()), ); - let mut branch_constraints = Vec::with_capacity(branches.len() + 1); - branch_constraints.push(expr_con); + let branch_var = *expr_var; + let branch_type = Variable(branch_var); - match &expected { - FromAnnotation(name, arity, ann_source, _typ) => { - // NOTE deviation from elm. - // - // in elm, `_typ` is used, but because we have this `expr_var` too - // and need to constrain it, this is what works and gives better error messages - let typ = Type::Variable(*expr_var); - - for (index, when_branch) in branches.iter().enumerate() { - let pattern_region = - Region::across_all(when_branch.patterns.iter().map(|v| &v.region)); - - let branch_con = constrain_when_branch( - constraints, - env, - when_branch.value.region, - when_branch, - PExpected::ForReason( - PReason::WhenMatch { - index: HumanIndex::zero_based(index), - }, - cond_type.clone(), - pattern_region, - ), - FromAnnotation( - name.clone(), - *arity, - AnnotationSource::TypedWhenBranch { - index: HumanIndex::zero_based(index), - region: ann_source.region(), - }, - typ.clone(), - ), - ); - - branch_constraints.push(branch_con); + let branch_expr_reason = + |expected: &Expected, index, branch_region| match expected { + FromAnnotation(name, arity, ann_source, _typ) => { + // NOTE deviation from elm. + // + // in elm, `_typ` is used, but because we have this `expr_var` too + // and need to constrain it, this is what works and gives better error messages + FromAnnotation( + name.clone(), + *arity, + AnnotationSource::TypedWhenBranch { + index, + region: ann_source.region(), + }, + branch_type.clone(), + ) } - branch_constraints.push(constraints.equal_types_var( - *expr_var, - expected, - Category::When, - region, - )); + _ => ForReason( + Reason::WhenBranch { index }, + branch_type.clone(), + branch_region, + ), + }; - return constraints.exists_many([cond_var, *expr_var], branch_constraints); - } + let mut branch_cons = Vec::with_capacity(branches.len()); + let mut pattern_cons = Vec::with_capacity(branches.len()); - _ => { - let branch_var = *expr_var; - let branch_type = Variable(branch_var); - let mut branch_cons = Vec::with_capacity(branches.len()); + for (index, when_branch) in branches.iter().enumerate() { + let pattern_region = + Region::across_all(when_branch.patterns.iter().map(|v| &v.region)); - for (index, when_branch) in branches.iter().enumerate() { - let pattern_region = - Region::across_all(when_branch.patterns.iter().map(|v| &v.region)); - let branch_con = constrain_when_branch( - constraints, - env, - region, - when_branch, - PExpected::ForReason( - PReason::WhenMatch { - index: HumanIndex::zero_based(index), - }, - cond_type.clone(), - pattern_region, - ), - ForReason( - Reason::WhenBranch { - index: HumanIndex::zero_based(index), - }, - branch_type.clone(), - when_branch.value.region, - ), - ); + let (pattern_con, branch_con) = constrain_when_branch( + constraints, + env, + region, + when_branch, + PExpected::ForReason( + PReason::WhenMatch { + index: HumanIndex::zero_based(index), + }, + cond_type.clone(), + pattern_region, + ), + branch_expr_reason( + &expected, + HumanIndex::zero_based(index), + when_branch.value.region, + ), + ); - branch_cons.push(branch_con); - } - - // Deviation: elm adds another layer of And nesting - // - // Record the original conditional expression's constraint. - // Each branch's pattern must have the same type - // as the condition expression did. - // - // The return type of each branch must equal the return type of - // the entire when-expression. - branch_cons.push(constraints.equal_types_var( - branch_var, - expected, - Category::When, - region, - )); - branch_constraints.push(constraints.and_constraint(branch_cons)); - } + pattern_cons.push(pattern_con); + branch_cons.push(branch_con); } + // Deviation: elm adds another layer of And nesting + // + // Record the original conditional expression's constraint. + // Each branch's pattern must have the same type + // as the condition expression did. + // + // The return type of each branch must equal the return type of + // the entire when-expression. + // branch_cons.extend(pattern_cons); + // branch_constraints.push(constraints.and_constraint(pattern_cons)); + let mut total_cons = Vec::with_capacity(1 + 2 * branches.len() + 1); + total_cons.push(expr_con); + total_cons.extend(pattern_cons); + total_cons.extend(branch_cons); + total_cons.push(constraints.equal_types_var( + branch_var, + expected, + Category::When, + region, + )); + + let branch_constraints = constraints.and_constraint(total_cons); + // exhautiveness checking happens when converting to mono::Expr - constraints.exists_many([cond_var, *expr_var], branch_constraints) + // ...for now + constraints.exists([cond_var, *expr_var], branch_constraints) } Access { record_var, @@ -1087,6 +1068,8 @@ pub fn constrain_expr( } } +/// Constrain a when branch, returning a pair of constraints (pattern constraint, body constraint). +/// We want to constraint all pattern constraints in a "when" before body constraints. #[inline(always)] fn constrain_when_branch( constraints: &mut Constraints, @@ -1095,7 +1078,7 @@ fn constrain_when_branch( when_branch: &WhenBranch, pattern_expected: PExpected, expr_expected: Expected, -) -> Constraint { +) -> (Constraint, Constraint) { let ret_constraint = constrain_expr( constraints, env, @@ -1123,7 +1106,7 @@ fn constrain_when_branch( ); } - if let Some(loc_guard) = &when_branch.guard { + let (pattern_constraints, body_constraints) = if let Some(loc_guard) = &when_branch.guard { let guard_constraint = constrain_expr( constraints, env, @@ -1140,17 +1123,40 @@ fn constrain_when_branch( let state_constraints = constraints.and_constraint(state.constraints); let inner = constraints.let_constraint([], [], [], guard_constraint, ret_constraint); - constraints.let_constraint([], state.vars, state.headers, state_constraints, inner) + (state_constraints, inner) } else { let state_constraints = constraints.and_constraint(state.constraints); - constraints.let_constraint( - [], - state.vars, - state.headers, - state_constraints, - ret_constraint, - ) - } + (state_constraints, ret_constraint) + }; + + // Our goal is to constrain and introduce variables in all pattern when branch patterns before + // looking at their bodies. + // + // pat1 -> body1 + // *^^^ +~~~~ + // pat2 -> body2 + // *^^^ +~~~~ + // + // * solve first + // + solve second + // + // For a single pattern/body pair, we must introduce variables and symbols defined in the + // pattern before solving the body, since those definitions are effectively let-bound. + // + // But also, we'd like to solve all branch pattern constraints in one swoop before looking at + // the bodies, because the patterns may have presence constraints that expect to be built up + // together. + // + // For this reason, we distinguish the two - and introduce variables in the branch patterns + // as part of the pattern constraint, while only binding those variables during solving of the + // bodies. + let pattern_introduction_constraints = + constraints.let_constraint([], state.vars, [], pattern_constraints, Constraint::True); + + let branch_body_constraints = + constraints.let_constraint([], [], state.headers, Constraint::True, body_constraints); + + (pattern_introduction_constraints, branch_body_constraints) } fn constrain_field( diff --git a/compiler/solve/tests/solve_expr.rs b/compiler/solve/tests/solve_expr.rs index 678175cbb5..c3dbf84746 100644 --- a/compiler/solve/tests/solve_expr.rs +++ b/compiler/solve/tests/solve_expr.rs @@ -5932,4 +5932,21 @@ mod solve_expr { "a -> U64 | a has Hash", ) } + + #[test] + fn when_branch_and_body_flipflop() { + infer_eq_without_problem( + indoc!( + r#" + func = \record -> + when record.tag is + A -> { record & tag: B } + B -> { record & tag: A } + + func + "# + ), + "{ tag : [ A, B ] }a -> { tag : [ A, B ] }a", + ) + } }