From d84e98751efedab41f4d5755add30f2f3797cbb3 Mon Sep 17 00:00:00 2001 From: Ayaz Hafiz Date: Mon, 7 Nov 2022 16:49:53 -0600 Subject: [PATCH] Always feed ExpectedTypeIndex to expr constraining --- crates/compiler/constrain/src/builtins.rs | 28 +-- crates/compiler/constrain/src/expr.rs | 244 ++++++++++++---------- 2 files changed, 140 insertions(+), 132 deletions(-) diff --git a/crates/compiler/constrain/src/builtins.rs b/crates/compiler/constrain/src/builtins.rs index 3aee4ab16b..f51ec90b87 100644 --- a/crates/compiler/constrain/src/builtins.rs +++ b/crates/compiler/constrain/src/builtins.rs @@ -1,5 +1,5 @@ use arrayvec::ArrayVec; -use roc_can::constraint::{Constraint, Constraints, TypeOrVar}; +use roc_can::constraint::{Constraint, Constraints, ExpectedTypeIndex}; use roc_can::expected::Expected::{self, *}; use roc_can::num::{FloatBound, FloatWidth, IntBound, IntLitWidth, NumBound, SignDemand}; use roc_module::symbol::Symbol; @@ -71,7 +71,7 @@ pub fn int_literal( constraints: &mut Constraints, num_var: Variable, precision_var: Variable, - expected: Expected, + expected: ExpectedTypeIndex, region: Region, bound: IntBound, ) -> Constraint { @@ -97,10 +97,7 @@ pub fn int_literal( constrs.extend([ constraints.equal_types(num_type_index, expect_precision_var, Category::Int, region), - { - let expected_index = constraints.push_expected_type(expected); - constraints.equal_types(num_type_index, expected_index, Category::Int, region) - }, + constraints.equal_types(num_type_index, expected, Category::Int, region), ]); // TODO the precision_var is not part of the exists here; for float it is. Which is correct? @@ -112,7 +109,7 @@ pub fn single_quote_literal( constraints: &mut Constraints, num_var: Variable, precision_var: Variable, - expected: Expected, + expected: ExpectedTypeIndex, region: Region, bound: SingleQuoteBound, ) -> Constraint { @@ -143,10 +140,7 @@ pub fn single_quote_literal( Category::Character, region, ), - { - let expected_index = constraints.push_expected_type(expected); - constraints.equal_types(num_type_index, expected_index, Category::Character, region) - }, + constraints.equal_types(num_type_index, expected, Category::Character, region), ]); let and_constraint = constraints.and_constraint(constrs); @@ -158,7 +152,7 @@ pub fn float_literal( constraints: &mut Constraints, num_var: Variable, precision_var: Variable, - expected: Expected, + expected: ExpectedTypeIndex, region: Region, bound: FloatBound, ) -> Constraint { @@ -183,10 +177,7 @@ pub fn float_literal( constrs.extend([ constraints.equal_types(num_type_index, expect_precision_var, Category::Frac, region), - { - let expected_index = constraints.push_expected_type(expected); - constraints.equal_types(num_type_index, expected_index, Category::Frac, region) - }, + constraints.equal_types(num_type_index, expected, Category::Frac, region), ]); let and_constraint = constraints.and_constraint(constrs); @@ -197,7 +188,7 @@ pub fn float_literal( pub fn num_literal( constraints: &mut Constraints, num_var: Variable, - expected: Expected, + expected: ExpectedTypeIndex, region: Region, bound: NumBound, ) -> Constraint { @@ -213,8 +204,7 @@ pub fn num_literal( ); let type_index = constraints.push_type(num_type); - let expected_index = constraints.push_expected_type(expected); - constrs.extend([constraints.equal_types(type_index, expected_index, Category::Num, region)]); + constrs.extend([constraints.equal_types(type_index, expected, Category::Num, region)]); let and_constraint = constraints.and_constraint(constrs); constraints.exists([num_var], and_constraint) diff --git a/crates/compiler/constrain/src/expr.rs b/crates/compiler/constrain/src/expr.rs index 09318c75eb..cf7ce315ac 100644 --- a/crates/compiler/constrain/src/expr.rs +++ b/crates/compiler/constrain/src/expr.rs @@ -195,10 +195,15 @@ fn constrain_expr_inner( ) -> Constraint { match expr { &Int(var, precision, _, _, bound) => { + let expected = constraints.push_expected_type(expected); int_literal(constraints, var, precision, expected, region, bound) } - &Num(var, _, _, bound) => num_literal(constraints, var, expected, region, bound), + &Num(var, _, _, bound) => { + let expected = constraints.push_expected_type(expected); + num_literal(constraints, var, expected, region, bound) + } &Float(var, precision, _, _, bound) => { + let expected = constraints.push_expected_type(expected); float_literal(constraints, var, precision, expected, region, bound) } EmptyRecord => constrain_empty_record(constraints, region, expected), @@ -315,14 +320,17 @@ fn constrain_expr_inner( let expected_index = constraints.push_expected_type(expected); constraints.equal_types(str_index, expected_index, Category::Str, region) } - SingleQuote(num_var, precision_var, _, bound) => single_quote_literal( - constraints, - *num_var, - *precision_var, - expected, - region, - *bound, - ), + SingleQuote(num_var, precision_var, _, bound) => { + let expected = constraints.push_expected_type(expected); + single_quote_literal( + constraints, + *num_var, + *precision_var, + expected, + region, + *bound, + ) + } List { elem_var, loc_elems, @@ -343,14 +351,14 @@ fn constrain_expr_inner( let mut list_constraints = Vec::with_capacity(1 + loc_elems.len()); for (index, loc_elem) in loc_elems.iter().enumerate() { - let elem_expected = ForReason( + let elem_expected = constraints.push_expected_type(ForReason( Reason::ElemInList { index: HumanIndex::zero_based(index), }, list_elem_type_index, loc_elem.region, - ); - let constraint = constrain_expr_inner( + )); + let constraint = constrain_expr( constraints, env, loc_elem.region, @@ -386,7 +394,7 @@ fn constrain_expr_inner( let fn_type_index = constraints.push_type(Variable(*fn_var)); let fn_region = loc_fn.region; - let fn_expected = NoExpectation(fn_type_index); + let fn_expected = constraints.push_expected_type(NoExpectation(fn_type_index)); let fn_reason = Reason::FnCall { name: opt_symbol, @@ -394,7 +402,7 @@ fn constrain_expr_inner( }; let fn_con = - constrain_expr_inner(constraints, env, loc_fn.region, &loc_fn.value, fn_expected); + constrain_expr(constraints, env, loc_fn.region, &loc_fn.value, fn_expected); // The function's return type let ret_type = Variable(*ret_var); @@ -421,8 +429,9 @@ fn constrain_expr_inner( name: opt_symbol, arg_index: HumanIndex::zero_based(index), }; - let expected_arg = ForReason(reason, arg_type_index, region); - let arg_con = constrain_expr_inner( + let expected_arg = + constraints.push_expected_type(ForReason(reason, arg_type_index, region)); + let arg_con = constrain_expr( constraints, env, loc_arg.region, @@ -526,10 +535,14 @@ fn constrain_expr_inner( } => { let expected_bool = { let bool_type = constraints.push_type(Type::Variable(Variable::BOOL)); - Expected::ForReason(Reason::ExpectCondition, bool_type, loc_condition.region) + constraints.push_expected_type(Expected::ForReason( + Reason::ExpectCondition, + bool_type, + loc_condition.region, + )) }; - let cond_con = constrain_expr_inner( + let cond_con = constrain_expr( constraints, env, loc_condition.region, @@ -537,7 +550,8 @@ fn constrain_expr_inner( expected_bool, ); - let continuation_con = constrain_expr_inner( + let expected = constraints.push_expected_type(expected); + let continuation_con = constrain_expr( constraints, env, loc_continuation.region, @@ -577,10 +591,14 @@ fn constrain_expr_inner( } => { let expected_bool = { let bool_type = constraints.push_type(Type::Variable(Variable::BOOL)); - Expected::ForReason(Reason::ExpectCondition, bool_type, loc_condition.region) + constraints.push_expected_type(Expected::ForReason( + Reason::ExpectCondition, + bool_type, + loc_condition.region, + )) }; - let cond_con = constrain_expr_inner( + let cond_con = constrain_expr( constraints, env, loc_condition.region, @@ -588,7 +606,8 @@ fn constrain_expr_inner( expected_bool, ); - let continuation_con = constrain_expr_inner( + let expected = constraints.push_expected_type(expected); + let continuation_con = constrain_expr( constraints, env, loc_continuation.region, @@ -629,16 +648,17 @@ fn constrain_expr_inner( } => { let expect_bool = |constraints: &mut Constraints, region| { let bool_type = constraints.push_type(Type::Variable(Variable::BOOL)); - Expected::ForReason(Reason::IfCondition, bool_type, region) + constraints.push_expected_type(Expected::ForReason( + Reason::IfCondition, + bool_type, + region, + )) }; let mut branch_cons = Vec::with_capacity(2 * branches.len() + 3); // TODO why does this cond var exist? is it for error messages? let first_cond_region = branches[0].0.region; - let expected_bool = { - let expected = expect_bool(constraints, first_cond_region); - constraints.push_expected_type(expected) - }; + let expected_bool = expect_bool(constraints, first_cond_region); let cond_var_is_bool_con = constraints.equal_types_var( *cond_var, expected_bool, @@ -653,7 +673,7 @@ fn constrain_expr_inner( let num_branches = branches.len() + 1; for (index, (loc_cond, loc_body)) in branches.iter().enumerate() { let expected_bool = expect_bool(constraints, loc_cond.region); - let cond_con = constrain_expr_inner( + let cond_con = constrain_expr( constraints, env, loc_cond.region, @@ -661,42 +681,45 @@ fn constrain_expr_inner( expected_bool, ); - let then_con = constrain_expr_inner( + let expected_then = constraints.push_expected_type(FromAnnotation( + name.clone(), + arity, + AnnotationSource::TypedIfBranch { + index: HumanIndex::zero_based(index), + num_branches, + region: ann_source.region(), + }, + tipe, + )); + + let then_con = constrain_expr( constraints, env, loc_body.region, &loc_body.value, - FromAnnotation( - name.clone(), - arity, - AnnotationSource::TypedIfBranch { - index: HumanIndex::zero_based(index), - num_branches, - region: ann_source.region(), - }, - tipe, - ), + expected_then, ); branch_cons.push(cond_con); branch_cons.push(then_con); } - let else_con = constrain_expr_inner( + let expected_else = constraints.push_expected_type(FromAnnotation( + name, + arity, + AnnotationSource::TypedIfBranch { + index: HumanIndex::zero_based(branches.len()), + num_branches, + region: ann_source.region(), + }, + tipe, + )); + let else_con = constrain_expr( constraints, env, final_else.region, &final_else.value, - FromAnnotation( - name, - arity, - AnnotationSource::TypedIfBranch { - index: HumanIndex::zero_based(branches.len()), - num_branches, - region: ann_source.region(), - }, - tipe, - ), + expected_else, ); let expected_result_type = constraints.push_expected_type(NoExpectation(tipe)); @@ -718,7 +741,7 @@ fn constrain_expr_inner( for (index, (loc_cond, loc_body)) in branches.iter().enumerate() { let expected_bool = expect_bool(constraints, loc_cond.region); - let cond_con = constrain_expr_inner( + let cond_con = constrain_expr( constraints, env, loc_cond.region, @@ -726,37 +749,39 @@ fn constrain_expr_inner( expected_bool, ); - let then_con = constrain_expr_inner( + let expected_then = constraints.push_expected_type(ForReason( + Reason::IfBranch { + index: HumanIndex::zero_based(index), + total_branches: branches.len(), + }, + branch_var_index, + loc_body.region, + )); + let then_con = constrain_expr( constraints, env, loc_body.region, &loc_body.value, - ForReason( - Reason::IfBranch { - index: HumanIndex::zero_based(index), - total_branches: branches.len(), - }, - branch_var_index, - loc_body.region, - ), + expected_then, ); branch_cons.push(cond_con); branch_cons.push(then_con); } - let else_con = constrain_expr_inner( + let expected_else = constraints.push_expected_type(ForReason( + Reason::IfBranch { + index: HumanIndex::zero_based(branches.len()), + total_branches: branches.len() + 1, + }, + branch_var_index, + final_else.region, + )); + let else_con = constrain_expr( constraints, env, final_else.region, &final_else.value, - ForReason( - Reason::IfBranch { - index: HumanIndex::zero_based(branches.len()), - total_branches: branches.len() + 1, - }, - branch_var_index, - final_else.region, - ), + expected_else, ); let expected = constraints.push_expected_type(expected); @@ -914,12 +939,14 @@ fn constrain_expr_inner( // First, solve the condition type. let real_cond_var = *real_cond_var; let real_cond_type = constraints.push_type(Type::Variable(real_cond_var)); - let cond_constraint = constrain_expr_inner( + let expected_real_cond = + constraints.push_expected_type(Expected::NoExpectation(real_cond_type)); + let cond_constraint = constrain_expr( constraints, env, loc_cond.region, &loc_cond.value, - Expected::NoExpectation(real_cond_type), + expected_real_cond, ); pattern_cons.push(cond_constraint); @@ -993,13 +1020,9 @@ fn constrain_expr_inner( let record_con = constraints.equal_types_var(*record_var, record_expected, category.clone(), region); - let constraint = constrain_expr_inner( - constraints, - env, - region, - &loc_expr.value, - NoExpectation(record_type), - ); + let expected_record = constraints.push_expected_type(NoExpectation(record_type)); + let constraint = + constrain_expr(constraints, env, region, &loc_expr.value, expected_record); let expected = constraints.push_expected_type(expected); let eq = constraints.equal_types_var(field_var, expected, category, region); @@ -1087,13 +1110,9 @@ fn constrain_expr_inner( ) } LetRec(defs, loc_ret, cycle_mark) => { - let body_con = constrain_expr_inner( - constraints, - env, - loc_ret.region, - &loc_ret.value, - expected.clone(), - ); + let expected = constraints.push_expected_type(expected); + let body_con = + constrain_expr(constraints, env, loc_ret.region, &loc_ret.value, expected); constrain_recursive_defs(constraints, env, defs, body_con, *cycle_mark) } @@ -1109,13 +1128,9 @@ fn constrain_expr_inner( loc_ret = new_loc_ret; } - let mut body_con = constrain_expr_inner( - constraints, - env, - loc_ret.region, - &loc_ret.value, - expected.clone(), - ); + let expected = constraints.push_expected_type(expected); + let mut body_con = + constrain_expr(constraints, env, loc_ret.region, &loc_ret.value, expected); while let Some(def) = stack.pop() { body_con = constrain_def(constraints, env, def, body_con) @@ -1137,12 +1152,13 @@ fn constrain_expr_inner( for (var, loc_expr) in arguments { let var_index = constraints.push_type(Variable(*var)); - let arg_con = constrain_expr_inner( + let expected_arg = constraints.push_expected_type(NoExpectation(var_index)); + let arg_con = constrain_expr( constraints, env, loc_expr.region, &loc_expr.value, - Expected::NoExpectation(var_index), + expected_arg, ); arg_cons.push(arg_con); @@ -1226,12 +1242,14 @@ fn constrain_expr_inner( }); // Constrain the argument - let arg_con = constrain_expr_inner( + let expected_arg = + constraints.push_expected_type(Expected::NoExpectation(arg_type_index)); + let arg_con = constrain_expr( constraints, env, arg_loc_expr.region, &arg_loc_expr.value, - Expected::NoExpectation(arg_type_index), + expected_arg, ); // Link the entire wrapped opaque type (with the now-constrained argument) to the @@ -1398,9 +1416,9 @@ fn constrain_expr_inner( op: *op, arg_index: HumanIndex::zero_based(index), }; - let expected_arg = ForReason(reason, arg_type, Region::zero()); - let arg_con = - constrain_expr_inner(constraints, env, Region::zero(), arg, expected_arg); + let expected_arg = + constraints.push_expected_type(ForReason(reason, arg_type, Region::zero())); + let arg_con = constrain_expr(constraints, env, Region::zero(), arg, expected_arg); arg_types.push(arg_type); arg_cons.push(arg_con); @@ -1441,9 +1459,9 @@ fn constrain_expr_inner( foreign_symbol: foreign_symbol.clone(), arg_index: HumanIndex::zero_based(index), }; - let expected_arg = ForReason(reason, arg_type, Region::zero()); - let arg_con = - constrain_expr_inner(constraints, env, Region::zero(), arg, expected_arg); + let expected_arg = + constraints.push_expected_type(ForReason(reason, arg_type, Region::zero())); + let arg_con = constrain_expr(constraints, env, Region::zero(), arg, expected_arg); arg_types.push(arg_type); arg_cons.push(arg_con); @@ -2094,14 +2112,14 @@ fn constrain_when_branch_help( let (pattern_constraints, delayed_is_open_constraints, body_constraints) = if let Some(loc_guard) = &when_branch.guard { let bool_index = constraints.push_type(Variable(Variable::BOOL)); + let expected_guard = constraints.push_expected_type(Expected::ForReason( + Reason::WhenGuard, + bool_index, + loc_guard.region, + )); - let guard_constraint = constrain_expr_inner( - constraints, - env, - region, - &loc_guard.value, - Expected::ForReason(Reason::WhenGuard, bool_index, loc_guard.region), - ); + let guard_constraint = + constrain_expr(constraints, env, region, &loc_guard.value, expected_guard); // must introduce the headers from the pattern before constraining the guard let delayed_is_open_constraints = state.delayed_is_open_constraints; @@ -2135,8 +2153,8 @@ fn constrain_field( loc_expr: &Loc, ) -> (Type, Constraint) { let field_type = constraints.push_type(Variable(field_var)); - let field_expected = NoExpectation(field_type); - let constraint = constrain_expr_inner( + let field_expected = constraints.push_expected_type(NoExpectation(field_type)); + let constraint = constrain_expr( constraints, env, loc_expr.region, @@ -3845,8 +3863,8 @@ fn constrain_field_update( ) -> (Variable, Type, Constraint) { let field_type = constraints.push_type(Variable(var)); let reason = Reason::RecordUpdateValue(field); - let expected = ForReason(reason, field_type, region); - let con = constrain_expr_inner(constraints, env, loc_expr.region, &loc_expr.value, expected); + let expected = constraints.push_expected_type(ForReason(reason, field_type, region)); + let con = constrain_expr(constraints, env, loc_expr.region, &loc_expr.value, expected); (var, Variable(var), con) }