diff --git a/compiler/constrain/src/pattern.rs b/compiler/constrain/src/pattern.rs index 9af7ebc0ae..56d678ac29 100644 --- a/compiler/constrain/src/pattern.rs +++ b/compiler/constrain/src/pattern.rs @@ -230,7 +230,6 @@ pub fn constrain_pattern( RecordField::Required(pat_type) } DestructType::Optional(expr_var, loc_expr) => { - // Eq(Type, Expected, Category, Region), let expr_expected = Expected::ForReason( Reason::RecordDefaultField(label.clone()), pat_type.clone(), diff --git a/compiler/constrain/src/uniq.rs b/compiler/constrain/src/uniq.rs index 073ec5d01d..27e6e8615c 100644 --- a/compiler/constrain/src/uniq.rs +++ b/compiler/constrain/src/uniq.rs @@ -144,7 +144,10 @@ pub struct PatternState { } fn constrain_pattern( + env: &Env, var_store: &mut VarStore, + var_usage: &VarUsage, + applied_usage_constraint: &mut ImSet, state: &mut PatternState, pattern: &Located, expected: PExpected, @@ -246,14 +249,46 @@ fn constrain_pattern( PExpected::NoExpectation(pat_type.clone()), )); state.vars.push(*guard_var); - constrain_pattern(var_store, state, loc_guard, expected); + constrain_pattern( + env, + var_store, + var_usage, + applied_usage_constraint, + state, + loc_guard, + expected, + ); RecordField::Required(pat_type) } - DestructType::Optional(_expr_var, _loc_expr) => { - todo!("Add a constraint for the default value."); + DestructType::Optional(expr_var, loc_expr) => { + let expr_expected = Expected::ForReason( + Reason::RecordDefaultField(label.clone()), + pat_type.clone(), + loc_expr.region, + ); - // RecordField::Optional(pat_type) + state.constraints.push(Constraint::Eq( + Type::Variable(*expr_var), + expr_expected.clone(), + Category::DefaultValue(label.clone()), + region, + )); + + state.vars.push(*expr_var); + + let expr_con = constrain_expr( + env, + var_store, + var_usage, + applied_usage_constraint, + loc_expr.region, + &loc_expr.value, + expr_expected, + ); + state.constraints.push(expr_con); + + RecordField::Optional(pat_type) } DestructType::Required => { // No extra constraints necessary. @@ -317,7 +352,15 @@ fn constrain_pattern( argument_types.push(pattern_type.clone()); let expected = PExpected::NoExpectation(pattern_type); - constrain_pattern(var_store, state, loc_pattern, expected); + constrain_pattern( + env, + var_store, + var_usage, + applied_usage_constraint, + state, + loc_pattern, + expected, + ); } let tag_union_uniq_type = { @@ -673,7 +716,15 @@ pub fn constrain_expr( pattern_types.push(pattern_type); - constrain_pattern(var_store, &mut state, loc_pattern, pattern_expected); + constrain_pattern( + env, + var_store, + var_usage, + applied_usage_constraint, + &mut state, + loc_pattern, + pattern_expected, + ); vars.push(*pattern_var); } @@ -1690,7 +1741,10 @@ fn constrain_when_branch( for loc_pattern in &when_branch.patterns { // mutates the state, so return value is not used constrain_pattern( + env, var_store, + var_usage, + applied_usage_constraint, &mut state, &loc_pattern, pattern_expected.clone(), @@ -1743,7 +1797,10 @@ fn constrain_when_branch( } fn constrain_def_pattern( + env: &Env, var_store: &mut VarStore, + var_usage: &VarUsage, + applied_usage_constraint: &mut ImSet, loc_pattern: &Located, expr_type: Type, ) -> PatternState { @@ -1757,7 +1814,15 @@ fn constrain_def_pattern( constraints: Vec::with_capacity(1), }; - constrain_pattern(var_store, &mut state, loc_pattern, pattern_expected); + constrain_pattern( + env, + var_store, + var_usage, + applied_usage_constraint, + &mut state, + loc_pattern, + pattern_expected, + ); state } @@ -2042,7 +2107,14 @@ fn constrain_def( let expr_var = def.expr_var; let expr_type = Type::Variable(expr_var); - let mut pattern_state = constrain_def_pattern(var_store, &def.loc_pattern, expr_type.clone()); + let mut pattern_state = constrain_def_pattern( + env, + var_store, + var_usage, + applied_usage_constraint, + &def.loc_pattern, + expr_type.clone(), + ); pattern_state.vars.push(expr_var); @@ -2236,7 +2308,10 @@ pub fn rec_defs_help( pattern_state.vars.push(expr_var); constrain_pattern( + env, var_store, + var_usage, + applied_usage_constraint, &mut pattern_state, &def.loc_pattern, pattern_expected, diff --git a/compiler/solve/tests/solve_uniq_expr.rs b/compiler/solve/tests/solve_uniq_expr.rs index 3d47a59a3c..e5776d70b3 100644 --- a/compiler/solve/tests/solve_uniq_expr.rs +++ b/compiler/solve/tests/solve_uniq_expr.rs @@ -3075,4 +3075,44 @@ mod solve_uniq_expr { "Attr * { a : (Attr * { x : (Attr * (Num (Attr * a))), y : (Attr * Float), z : (Attr * c) }), b : (Attr * { blah : (Attr * Str), x : (Attr * (Num (Attr * a))), y : (Attr * Float), z : (Attr * c) }) }" ); } + + #[test] + fn optional_field_function() { + infer_eq( + indoc!( + r#" + \{ x, y ? 0 } -> x + y + "# + ), + "Attr * (Attr (* | b | c) { x : (Attr b (Num (Attr b a))), y ? (Attr c (Num (Attr c a))) }* -> Attr d (Num (Attr d a)))" + ); + } + + #[test] + fn optional_field_let() { + infer_eq( + indoc!( + r#" + { x, y ? 0 } = { x: 32 } + + x + y + "# + ), + "Attr a (Num (Attr a *))", + ); + } + + #[test] + fn optional_field_when() { + infer_eq( + indoc!( + r#" + \r -> + when r is + { x, y ? 0 } -> x + y + "# + ), + "Attr * (Attr (* | b | c) { x : (Attr b (Num (Attr b a))), y ? (Attr c (Num (Attr c a))) }* -> Attr d (Num (Attr d a)))" + ); + } }