Support bound variables in multiple patterns

This commit is contained in:
Ayaz Hafiz 2022-07-21 11:40:09 -04:00
parent bf8fc0d0de
commit ce8b50caea
No known key found for this signature in database
GPG key ID: 0E2A37416A25EF58
3 changed files with 59 additions and 4 deletions

View file

@ -1824,14 +1824,45 @@ fn constrain_when_branch_help(
for (i, loc_pattern) in when_branch.patterns.iter().enumerate() {
let pattern_expected = pattern_expected(HumanIndex::zero_based(i), loc_pattern.region);
let mut partial_state = PatternState::default();
constrain_pattern(
constraints,
env,
&loc_pattern.value,
loc_pattern.region,
pattern_expected,
&mut state,
&mut partial_state,
);
state.vars.extend(partial_state.vars);
state.constraints.extend(partial_state.constraints);
state
.delayed_is_open_constraints
.extend(partial_state.delayed_is_open_constraints);
if i == 0 {
state.headers.extend(partial_state.headers);
} else {
debug_assert!(
state.headers.keys().all(|sym| partial_state.headers.contains_key(sym)) &&
partial_state.headers.keys().all(|sym| state.headers.contains_key(sym)),
"State and partial state headers differ in bound symbols, should have been caught in canonicalization");
// Make sure the bound variables in the patterns on the same branch agree in their types.
for (sym, typ1) in state.headers.iter() {
let typ2 = partial_state
.headers
.get(sym)
.expect("bound variable in branch not bound in pattern!");
state.constraints.push(constraints.equal_types(
typ1.value.clone(),
Expected::NoExpectation(typ2.value.clone()),
Category::When,
typ2.region,
));
}
}
}
let (pattern_constraints, body_constraints) = if let Some(loc_guard) = &when_branch.guard {

View file

@ -7360,15 +7360,21 @@ mod solve_expr {
#[test]
fn shared_pattern_variable_in_when_branches() {
infer_eq_without_problem(
infer_queries!(
indoc!(
r#"
when A "" is
# ^^^^
A x | B x -> x
C y | D y -> y
# ^ ^ ^
"#
),
"",
@r###"
A "" : [A Str, B Str]
x : Str
x : Str
x : Str
"###
);
}
}

View file

@ -3641,3 +3641,21 @@ fn recursive_call_capturing_function() {
i64
)
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn shared_pattern_variable_in_when_branches() {
assert_evals_to!(
indoc!(
r#"
f = \t ->
when t is
A x | B x -> x
{a: f (A 15u8), b: (B 31u8)}
"#
),
(15u8, 31u8),
(u8, u8)
);
}