Solve all when branch pattern constraints before solving their bodies

Closes #2886
This commit is contained in:
Ayaz Hafiz 2022-04-18 11:03:25 -04:00
parent e0c9931326
commit 2856a38236
No known key found for this signature in database
GPG key ID: 0E2A37416A25EF58
2 changed files with 127 additions and 104 deletions

View file

@ -596,66 +596,42 @@ pub fn constrain_expr(
NoExpectation(cond_type.clone()), NoExpectation(cond_type.clone()),
); );
let mut branch_constraints = Vec::with_capacity(branches.len() + 1); let branch_var = *expr_var;
branch_constraints.push(expr_con); let branch_type = Variable(branch_var);
match &expected { let branch_expr_reason =
|expected: &Expected<Type>, index, branch_region| match expected {
FromAnnotation(name, arity, ann_source, _typ) => { FromAnnotation(name, arity, ann_source, _typ) => {
// NOTE deviation from elm. // NOTE deviation from elm.
// //
// in elm, `_typ` is used, but because we have this `expr_var` too // in elm, `_typ` is used, but because we have this `expr_var` too
// and need to constrain it, this is what works and gives better error messages // and need to constrain it, this is what works and gives better error messages
let typ = Type::Variable(*expr_var);
for (index, when_branch) in branches.iter().enumerate() {
let pattern_region =
Region::across_all(when_branch.patterns.iter().map(|v| &v.region));
let branch_con = constrain_when_branch(
constraints,
env,
when_branch.value.region,
when_branch,
PExpected::ForReason(
PReason::WhenMatch {
index: HumanIndex::zero_based(index),
},
cond_type.clone(),
pattern_region,
),
FromAnnotation( FromAnnotation(
name.clone(), name.clone(),
*arity, *arity,
AnnotationSource::TypedWhenBranch { AnnotationSource::TypedWhenBranch {
index: HumanIndex::zero_based(index), index,
region: ann_source.region(), region: ann_source.region(),
}, },
typ.clone(), branch_type.clone(),
)
}
_ => ForReason(
Reason::WhenBranch { index },
branch_type.clone(),
branch_region,
), ),
); };
branch_constraints.push(branch_con);
}
branch_constraints.push(constraints.equal_types_var(
*expr_var,
expected,
Category::When,
region,
));
return constraints.exists_many([cond_var, *expr_var], branch_constraints);
}
_ => {
let branch_var = *expr_var;
let branch_type = Variable(branch_var);
let mut branch_cons = Vec::with_capacity(branches.len()); let mut branch_cons = Vec::with_capacity(branches.len());
let mut pattern_cons = Vec::with_capacity(branches.len());
for (index, when_branch) in branches.iter().enumerate() { for (index, when_branch) in branches.iter().enumerate() {
let pattern_region = let pattern_region =
Region::across_all(when_branch.patterns.iter().map(|v| &v.region)); Region::across_all(when_branch.patterns.iter().map(|v| &v.region));
let branch_con = constrain_when_branch(
let (pattern_con, branch_con) = constrain_when_branch(
constraints, constraints,
env, env,
region, region,
@ -667,15 +643,14 @@ pub fn constrain_expr(
cond_type.clone(), cond_type.clone(),
pattern_region, pattern_region,
), ),
ForReason( branch_expr_reason(
Reason::WhenBranch { &expected,
index: HumanIndex::zero_based(index), HumanIndex::zero_based(index),
},
branch_type.clone(),
when_branch.value.region, when_branch.value.region,
), ),
); );
pattern_cons.push(pattern_con);
branch_cons.push(branch_con); branch_cons.push(branch_con);
} }
@ -687,18 +662,24 @@ pub fn constrain_expr(
// //
// The return type of each branch must equal the return type of // The return type of each branch must equal the return type of
// the entire when-expression. // the entire when-expression.
branch_cons.push(constraints.equal_types_var( // branch_cons.extend(pattern_cons);
// branch_constraints.push(constraints.and_constraint(pattern_cons));
let mut total_cons = Vec::with_capacity(1 + 2 * branches.len() + 1);
total_cons.push(expr_con);
total_cons.extend(pattern_cons);
total_cons.extend(branch_cons);
total_cons.push(constraints.equal_types_var(
branch_var, branch_var,
expected, expected,
Category::When, Category::When,
region, region,
)); ));
branch_constraints.push(constraints.and_constraint(branch_cons));
} let branch_constraints = constraints.and_constraint(total_cons);
}
// exhautiveness checking happens when converting to mono::Expr // exhautiveness checking happens when converting to mono::Expr
constraints.exists_many([cond_var, *expr_var], branch_constraints) // ...for now
constraints.exists([cond_var, *expr_var], branch_constraints)
} }
Access { Access {
record_var, record_var,
@ -1087,6 +1068,8 @@ pub fn constrain_expr(
} }
} }
/// Constrain a when branch, returning a pair of constraints (pattern constraint, body constraint).
/// We want to constraint all pattern constraints in a "when" before body constraints.
#[inline(always)] #[inline(always)]
fn constrain_when_branch( fn constrain_when_branch(
constraints: &mut Constraints, constraints: &mut Constraints,
@ -1095,7 +1078,7 @@ fn constrain_when_branch(
when_branch: &WhenBranch, when_branch: &WhenBranch,
pattern_expected: PExpected<Type>, pattern_expected: PExpected<Type>,
expr_expected: Expected<Type>, expr_expected: Expected<Type>,
) -> Constraint { ) -> (Constraint, Constraint) {
let ret_constraint = constrain_expr( let ret_constraint = constrain_expr(
constraints, constraints,
env, env,
@ -1123,7 +1106,7 @@ fn constrain_when_branch(
); );
} }
if let Some(loc_guard) = &when_branch.guard { let (pattern_constraints, body_constraints) = if let Some(loc_guard) = &when_branch.guard {
let guard_constraint = constrain_expr( let guard_constraint = constrain_expr(
constraints, constraints,
env, env,
@ -1140,17 +1123,40 @@ fn constrain_when_branch(
let state_constraints = constraints.and_constraint(state.constraints); let state_constraints = constraints.and_constraint(state.constraints);
let inner = constraints.let_constraint([], [], [], guard_constraint, ret_constraint); let inner = constraints.let_constraint([], [], [], guard_constraint, ret_constraint);
constraints.let_constraint([], state.vars, state.headers, state_constraints, inner) (state_constraints, inner)
} else { } else {
let state_constraints = constraints.and_constraint(state.constraints); let state_constraints = constraints.and_constraint(state.constraints);
constraints.let_constraint( (state_constraints, ret_constraint)
[], };
state.vars,
state.headers, // Our goal is to constrain and introduce variables in all pattern when branch patterns before
state_constraints, // looking at their bodies.
ret_constraint, //
) // pat1 -> body1
} // *^^^ +~~~~
// pat2 -> body2
// *^^^ +~~~~
//
// * solve first
// + solve second
//
// For a single pattern/body pair, we must introduce variables and symbols defined in the
// pattern before solving the body, since those definitions are effectively let-bound.
//
// But also, we'd like to solve all branch pattern constraints in one swoop before looking at
// the bodies, because the patterns may have presence constraints that expect to be built up
// together.
//
// For this reason, we distinguish the two - and introduce variables in the branch patterns
// as part of the pattern constraint, while only binding those variables during solving of the
// bodies.
let pattern_introduction_constraints =
constraints.let_constraint([], state.vars, [], pattern_constraints, Constraint::True);
let branch_body_constraints =
constraints.let_constraint([], [], state.headers, Constraint::True, body_constraints);
(pattern_introduction_constraints, branch_body_constraints)
} }
fn constrain_field( fn constrain_field(

View file

@ -5932,4 +5932,21 @@ mod solve_expr {
"a -> U64 | a has Hash", "a -> U64 | a has Hash",
) )
} }
#[test]
fn when_branch_and_body_flipflop() {
infer_eq_without_problem(
indoc!(
r#"
func = \record ->
when record.tag is
A -> { record & tag: B }
B -> { record & tag: A }
func
"#
),
"{ tag : [ A, B ] }a -> { tag : [ A, B ] }a",
)
}
} }