mirror of
https://github.com/roc-lang/roc.git
synced 2025-08-03 11:52:19 +00:00
Properly type constrain all function types
This commit is contained in:
parent
20ba4a92de
commit
f857872903
4 changed files with 320 additions and 79 deletions
|
@ -168,37 +168,38 @@ fn constrain_untyped_closure(
|
|||
vars.push(closure_var);
|
||||
vars.push(fn_var);
|
||||
|
||||
let body_type = constraints.push_expected_type(ForReason(
|
||||
let return_type_index = constraints.push_expected_type(ForReason(
|
||||
Reason::FunctionOutput,
|
||||
return_type_index,
|
||||
loc_body_expr.region,
|
||||
));
|
||||
|
||||
let ret_constraint = env.with_fx_expectation(fx_var, None, |env| {
|
||||
constrain_expr(
|
||||
let returns_constraint = env.with_fx_expectation(fx_var, None, |env| {
|
||||
let return_con = constrain_expr(
|
||||
types,
|
||||
constraints,
|
||||
env,
|
||||
loc_body_expr.region,
|
||||
&loc_body_expr.value,
|
||||
body_type,
|
||||
)
|
||||
});
|
||||
|
||||
let mut early_return_constraints = Vec::with_capacity(early_returns.len());
|
||||
for (early_return_variable, early_return_region) in early_returns {
|
||||
let early_return_var = constraints.push_variable(*early_return_variable);
|
||||
let early_return_con = constraints.equal_types(
|
||||
early_return_var,
|
||||
body_type,
|
||||
Category::Return,
|
||||
*early_return_region,
|
||||
return_type_index,
|
||||
);
|
||||
|
||||
early_return_constraints.push(early_return_con);
|
||||
}
|
||||
let mut return_constraints = Vec::with_capacity(early_returns.len() + 1);
|
||||
return_constraints.push(return_con);
|
||||
|
||||
let early_returns_constraint = constraints.and_constraint(early_return_constraints);
|
||||
for (early_return_variable, early_return_region) in early_returns {
|
||||
let early_return_con = constraints.equal_types_var(
|
||||
*early_return_variable,
|
||||
return_type_index,
|
||||
Category::Return,
|
||||
*early_return_region,
|
||||
);
|
||||
|
||||
return_constraints.push(early_return_con);
|
||||
}
|
||||
|
||||
constraints.and_constraint(return_constraints)
|
||||
});
|
||||
|
||||
// make sure the captured symbols are sorted!
|
||||
debug_assert_eq!(captured_symbols.to_vec(), {
|
||||
|
@ -231,7 +232,7 @@ fn constrain_untyped_closure(
|
|||
pattern_state.vars,
|
||||
pattern_state.headers,
|
||||
pattern_state_constraints,
|
||||
ret_constraint,
|
||||
returns_constraint,
|
||||
Generalizable(true),
|
||||
),
|
||||
constraints.and_constraint(pattern_state.delayed_fx_suffix_constraints),
|
||||
|
@ -242,7 +243,6 @@ fn constrain_untyped_closure(
|
|||
region,
|
||||
fn_var,
|
||||
),
|
||||
early_returns_constraint,
|
||||
closure_constraint,
|
||||
constraints.flex_to_pure(fx_var),
|
||||
];
|
||||
|
@ -1423,23 +1423,20 @@ pub fn constrain_expr(
|
|||
return_var,
|
||||
} => {
|
||||
let return_type_index = constraints.push_variable(*return_var);
|
||||
|
||||
let expected_return_value = constraints.push_expected_type(ForReason(
|
||||
Reason::FunctionOutput,
|
||||
return_type_index,
|
||||
return_value.region,
|
||||
));
|
||||
|
||||
let return_con = constrain_expr(
|
||||
constrain_expr(
|
||||
types,
|
||||
constraints,
|
||||
env,
|
||||
return_value.region,
|
||||
&return_value.value,
|
||||
expected_return_value,
|
||||
);
|
||||
|
||||
constraints.exists([*return_var], return_con)
|
||||
)
|
||||
}
|
||||
Tag {
|
||||
tag_union_var: variant_var,
|
||||
|
@ -2062,21 +2059,6 @@ fn constrain_function_def(
|
|||
ret_type_index,
|
||||
));
|
||||
|
||||
let mut early_return_constraints = Vec::with_capacity(function_def.early_returns.len());
|
||||
for (early_return_variable, early_return_region) in &function_def.early_returns {
|
||||
let early_return_var = constraints.push_variable(*early_return_variable);
|
||||
let early_return_con = constraints.equal_types(
|
||||
early_return_var,
|
||||
return_type_annotation_expected,
|
||||
Category::Return,
|
||||
*early_return_region,
|
||||
);
|
||||
|
||||
early_return_constraints.push(early_return_con);
|
||||
}
|
||||
|
||||
let early_returns_constraint = constraints.and_constraint(early_return_constraints);
|
||||
|
||||
let solved_fn_type = {
|
||||
// TODO(types-soa) optimize for Variable
|
||||
let pattern_types = types.from_old_type_slice(
|
||||
|
@ -2090,8 +2072,8 @@ fn constrain_function_def(
|
|||
constraints.push_type(types, fn_type)
|
||||
};
|
||||
|
||||
let ret_constraint = {
|
||||
let con = constrain_expr(
|
||||
let returns_constraint = {
|
||||
let return_con = constrain_expr(
|
||||
types,
|
||||
constraints,
|
||||
env,
|
||||
|
@ -2099,7 +2081,33 @@ fn constrain_function_def(
|
|||
&loc_body_expr.value,
|
||||
return_type_annotation_expected,
|
||||
);
|
||||
attach_resolution_constraints(constraints, env, con)
|
||||
|
||||
let mut return_constraints =
|
||||
Vec::with_capacity(function_def.early_returns.len() + 1);
|
||||
return_constraints.push(return_con);
|
||||
|
||||
for (early_return_variable, early_return_region) in &function_def.early_returns {
|
||||
let early_return_type_expected =
|
||||
constraints.push_expected_type(Expected::ForReason(
|
||||
Reason::FunctionOutput,
|
||||
ret_type_index,
|
||||
*early_return_region,
|
||||
));
|
||||
|
||||
vars.push(*early_return_variable);
|
||||
let early_return_con = constraints.equal_types_var(
|
||||
*early_return_variable,
|
||||
early_return_type_expected,
|
||||
Category::Return,
|
||||
*early_return_region,
|
||||
);
|
||||
|
||||
return_constraints.push(early_return_con);
|
||||
}
|
||||
|
||||
let returns_constraint = constraints.and_constraint(return_constraints);
|
||||
|
||||
attach_resolution_constraints(constraints, env, returns_constraint)
|
||||
};
|
||||
|
||||
vars.push(expr_var);
|
||||
|
@ -2114,13 +2122,12 @@ fn constrain_function_def(
|
|||
std::file!(),
|
||||
std::line!(),
|
||||
),
|
||||
early_returns_constraint,
|
||||
constraints.let_constraint(
|
||||
[],
|
||||
argument_pattern_state.vars,
|
||||
argument_pattern_state.headers,
|
||||
defs_constraint,
|
||||
ret_constraint,
|
||||
returns_constraint,
|
||||
// This is a syntactic function, it can be generalized
|
||||
Generalizable(true),
|
||||
),
|
||||
|
@ -2876,6 +2883,7 @@ fn constrain_typed_def(
|
|||
function_type: fn_var,
|
||||
closure_type: closure_var,
|
||||
return_type: ret_var,
|
||||
early_returns,
|
||||
fx_type: fx_var,
|
||||
captured_symbols,
|
||||
arguments,
|
||||
|
@ -2945,7 +2953,7 @@ fn constrain_typed_def(
|
|||
constraints.push_type(types, fn_type)
|
||||
};
|
||||
|
||||
let body_type = constraints.push_expected_type(FromAnnotation(
|
||||
let return_type = constraints.push_expected_type(FromAnnotation(
|
||||
def.loc_pattern.clone(),
|
||||
arguments.len(),
|
||||
AnnotationSource::TypedBody {
|
||||
|
@ -2954,18 +2962,35 @@ fn constrain_typed_def(
|
|||
ret_type_index,
|
||||
));
|
||||
|
||||
let ret_constraint = env.with_fx_expectation(fx_var, Some(annotation.region), |env| {
|
||||
constrain_expr(
|
||||
types,
|
||||
constraints,
|
||||
env,
|
||||
loc_body_expr.region,
|
||||
&loc_body_expr.value,
|
||||
body_type,
|
||||
)
|
||||
});
|
||||
let returns_constraint =
|
||||
env.with_fx_expectation(fx_var, Some(annotation.region), |env| {
|
||||
let return_con = constrain_expr(
|
||||
types,
|
||||
constraints,
|
||||
env,
|
||||
loc_body_expr.region,
|
||||
&loc_body_expr.value,
|
||||
return_type,
|
||||
);
|
||||
|
||||
let ret_constraint = attach_resolution_constraints(constraints, env, ret_constraint);
|
||||
let mut return_constraints = Vec::with_capacity(early_returns.len() + 1);
|
||||
return_constraints.push(return_con);
|
||||
|
||||
for (early_return_variable, early_return_region) in early_returns {
|
||||
let early_return_con = constraints.equal_types_var(
|
||||
*early_return_variable,
|
||||
return_type,
|
||||
Category::Return,
|
||||
*early_return_region,
|
||||
);
|
||||
|
||||
return_constraints.push(early_return_con);
|
||||
}
|
||||
|
||||
let returns_constraint = constraints.and_constraint(return_constraints);
|
||||
|
||||
attach_resolution_constraints(constraints, env, returns_constraint)
|
||||
});
|
||||
|
||||
vars.push(*fn_var);
|
||||
let defs_constraint = constraints.and_constraint(argument_pattern_state.constraints);
|
||||
|
@ -2978,7 +3003,7 @@ fn constrain_typed_def(
|
|||
argument_pattern_state.vars,
|
||||
argument_pattern_state.headers,
|
||||
defs_constraint,
|
||||
ret_constraint,
|
||||
returns_constraint,
|
||||
// This is a syntactic function, it can be generalized
|
||||
Generalizable(true),
|
||||
),
|
||||
|
@ -3985,18 +4010,38 @@ fn constraint_recursive_function(
|
|||
constraints.push_type(types, typ)
|
||||
};
|
||||
|
||||
let expr_con = env.with_fx_expectation(fx_var, Some(annotation.region), |env| {
|
||||
let expected = constraints.push_expected_type(NoExpectation(ret_type_index));
|
||||
constrain_expr(
|
||||
types,
|
||||
constraints,
|
||||
env,
|
||||
loc_body_expr.region,
|
||||
&loc_body_expr.value,
|
||||
expected,
|
||||
)
|
||||
});
|
||||
let expr_con = attach_resolution_constraints(constraints, env, expr_con);
|
||||
let returns_constraint =
|
||||
env.with_fx_expectation(fx_var, Some(annotation.region), |env| {
|
||||
let expected = constraints.push_expected_type(NoExpectation(ret_type_index));
|
||||
let return_con = constrain_expr(
|
||||
types,
|
||||
constraints,
|
||||
env,
|
||||
loc_body_expr.region,
|
||||
&loc_body_expr.value,
|
||||
expected,
|
||||
);
|
||||
|
||||
let mut return_constraints =
|
||||
Vec::with_capacity(function_def.early_returns.len() + 1);
|
||||
return_constraints.push(return_con);
|
||||
|
||||
for (early_return_variable, early_return_region) in &function_def.early_returns
|
||||
{
|
||||
let early_return_con = constraints.equal_types_var(
|
||||
*early_return_variable,
|
||||
expected,
|
||||
Category::Return,
|
||||
*early_return_region,
|
||||
);
|
||||
|
||||
return_constraints.push(early_return_con);
|
||||
}
|
||||
|
||||
let returns_constraint = constraints.and_constraint(return_constraints);
|
||||
|
||||
attach_resolution_constraints(constraints, env, returns_constraint)
|
||||
});
|
||||
|
||||
vars.push(expr_var);
|
||||
|
||||
|
@ -4008,7 +4053,7 @@ fn constraint_recursive_function(
|
|||
argument_pattern_state.vars,
|
||||
argument_pattern_state.headers,
|
||||
state_constraints,
|
||||
expr_con,
|
||||
returns_constraint,
|
||||
// Syntactic function can be generalized
|
||||
Generalizable(true),
|
||||
),
|
||||
|
@ -4470,6 +4515,7 @@ fn rec_defs_help(
|
|||
function_type: fn_var,
|
||||
closure_type: closure_var,
|
||||
return_type: ret_var,
|
||||
early_returns,
|
||||
fx_type: fx_var,
|
||||
captured_symbols,
|
||||
arguments,
|
||||
|
@ -4537,22 +4583,40 @@ fn rec_defs_help(
|
|||
let typ = types.function(pattern_types, lambda_set, ret_type, fx_type);
|
||||
constraints.push_type(types, typ)
|
||||
};
|
||||
let expr_con =
|
||||
let returns_constraint =
|
||||
env.with_fx_expectation(fx_var, Some(annotation.region), |env| {
|
||||
let body_type =
|
||||
let return_type_expected =
|
||||
constraints.push_expected_type(NoExpectation(ret_type_index));
|
||||
|
||||
constrain_expr(
|
||||
let return_con = constrain_expr(
|
||||
types,
|
||||
constraints,
|
||||
env,
|
||||
loc_body_expr.region,
|
||||
&loc_body_expr.value,
|
||||
body_type,
|
||||
)
|
||||
});
|
||||
return_type_expected,
|
||||
);
|
||||
|
||||
let expr_con = attach_resolution_constraints(constraints, env, expr_con);
|
||||
let mut return_constraints =
|
||||
Vec::with_capacity(early_returns.len() + 1);
|
||||
return_constraints.push(return_con);
|
||||
|
||||
for (early_return_variable, early_return_region) in early_returns {
|
||||
let early_return_con = constraints.equal_types_var(
|
||||
*early_return_variable,
|
||||
return_type_expected,
|
||||
Category::Return,
|
||||
*early_return_region,
|
||||
);
|
||||
|
||||
return_constraints.push(early_return_con);
|
||||
}
|
||||
|
||||
let returns_constraint =
|
||||
constraints.and_constraint(return_constraints);
|
||||
|
||||
attach_resolution_constraints(constraints, env, returns_constraint)
|
||||
});
|
||||
|
||||
vars.push(*fn_var);
|
||||
|
||||
|
@ -4567,7 +4631,7 @@ fn rec_defs_help(
|
|||
argument_pattern_state.vars,
|
||||
argument_pattern_state.headers,
|
||||
state_constraints,
|
||||
expr_con,
|
||||
returns_constraint,
|
||||
generalizable,
|
||||
),
|
||||
// Check argument suffixes against usage
|
||||
|
|
|
@ -108,6 +108,24 @@ fn early_return_solo() {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
|
||||
fn early_return_solo_annotated() {
|
||||
assert_evals_to!(
|
||||
r#"
|
||||
identity : Str -> Str
|
||||
identity = \x ->
|
||||
return x
|
||||
|
||||
identity "abc"
|
||||
"#,
|
||||
RocStr::from("abc"),
|
||||
RocStr,
|
||||
identity,
|
||||
true
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))]
|
||||
fn early_return_annotated_function() {
|
||||
|
@ -149,3 +167,99 @@ fn early_return_annotated_function() {
|
|||
RocList<RocStr>
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))]
|
||||
fn early_return_nested_annotated_function() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
app "test" provides [main] to "./platform"
|
||||
|
||||
validateInput : Str -> Result U64 [InvalidNumStr, LessThanFive]
|
||||
validateInput = \str ->
|
||||
failIfLessThanFive : U64 -> Result {} [LessThanFive]
|
||||
failIfLessThanFive = \n ->
|
||||
if n < 5 then
|
||||
Err LessThanFive
|
||||
else
|
||||
Ok {}
|
||||
|
||||
num = try Str.toU64 str
|
||||
|
||||
when failIfLessThanFive num is
|
||||
Err err ->
|
||||
return Err err
|
||||
|
||||
Ok {} ->
|
||||
Ok num
|
||||
|
||||
main : List Str
|
||||
main =
|
||||
["abc", "3", "7"]
|
||||
|> List.map validateInput
|
||||
|> List.map Inspect.toStr
|
||||
"#
|
||||
),
|
||||
RocList::from_slice(&[
|
||||
RocStr::from("(Err InvalidNumStr)"),
|
||||
RocStr::from("(Err LessThanFive)"),
|
||||
RocStr::from("(Ok 7)")
|
||||
]),
|
||||
RocList<RocStr>
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))]
|
||||
fn early_return_annotated_recursive_function() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
app "test" provides [main] to "./platform"
|
||||
|
||||
mightCallSecond : U64 -> Result U64 _
|
||||
mightCallSecond = \num ->
|
||||
nextNum =
|
||||
if num < 5 then
|
||||
return Err LessThanFive
|
||||
else
|
||||
num - 1
|
||||
|
||||
mightCallFirst nextNum
|
||||
|
||||
mightCallFirst : U64 -> Result U64 _
|
||||
mightCallFirst = \num ->
|
||||
nextNum =
|
||||
if num < 10 then
|
||||
return Err LessThanTen
|
||||
else
|
||||
num * 2
|
||||
|
||||
if nextNum > 25 then
|
||||
Ok nextNum
|
||||
else
|
||||
mightCallSecond nextNum
|
||||
|
||||
main : List Str
|
||||
main =
|
||||
[
|
||||
mightCallSecond 3,
|
||||
mightCallSecond 7,
|
||||
mightCallSecond 20,
|
||||
mightCallFirst 7,
|
||||
mightCallFirst 15,
|
||||
]
|
||||
|> List.map Inspect.toStr
|
||||
"#
|
||||
),
|
||||
RocList::from_slice(&[
|
||||
RocStr::from("(Err LessThanFive)"),
|
||||
RocStr::from("(Err LessThanTen)"),
|
||||
RocStr::from("(Ok 38)"),
|
||||
RocStr::from("(Err LessThanTen)"),
|
||||
RocStr::from("(Ok 30)")
|
||||
]),
|
||||
RocList<RocStr>
|
||||
);
|
||||
}
|
||||
|
|
48
crates/compiler/test_mono/generated/return_annotated.txt
generated
Normal file
48
crates/compiler/test_mono/generated/return_annotated.txt
generated
Normal file
|
@ -0,0 +1,48 @@
|
|||
procedure Bool.11 (#Attr.2, #Attr.3):
|
||||
let Bool.23 : Int1 = lowlevel Eq #Attr.2 #Attr.3;
|
||||
ret Bool.23;
|
||||
|
||||
procedure Str.26 (Str.83):
|
||||
let Str.246 : [C {}, C U64] = CallByName Str.66 Str.83;
|
||||
ret Str.246;
|
||||
|
||||
procedure Str.42 (#Attr.2):
|
||||
let Str.254 : {U64, U8} = lowlevel StrToNum #Attr.2;
|
||||
ret Str.254;
|
||||
|
||||
procedure Str.66 (Str.191):
|
||||
let Str.192 : {U64, U8} = CallByName Str.42 Str.191;
|
||||
let Str.252 : U8 = StructAtIndex 1 Str.192;
|
||||
let Str.253 : U8 = 0i64;
|
||||
let Str.249 : Int1 = CallByName Bool.11 Str.252 Str.253;
|
||||
if Str.249 then
|
||||
let Str.251 : U64 = StructAtIndex 0 Str.192;
|
||||
let Str.250 : [C {}, C U64] = TagId(1) Str.251;
|
||||
ret Str.250;
|
||||
else
|
||||
let Str.248 : {} = Struct {};
|
||||
let Str.247 : [C {}, C U64] = TagId(0) Str.248;
|
||||
ret Str.247;
|
||||
|
||||
procedure Test.3 (Test.4):
|
||||
joinpoint Test.14 Test.5:
|
||||
let Test.12 : [C {}, C U64] = TagId(1) Test.5;
|
||||
ret Test.12;
|
||||
in
|
||||
let Test.13 : [C {}, C U64] = CallByName Str.26 Test.4;
|
||||
let Test.18 : U8 = 1i64;
|
||||
let Test.19 : U8 = GetTagId Test.13;
|
||||
let Test.20 : Int1 = lowlevel Eq Test.18 Test.19;
|
||||
if Test.20 then
|
||||
let Test.6 : U64 = UnionAtIndex (Id 1) (Index 0) Test.13;
|
||||
jump Test.14 Test.6;
|
||||
else
|
||||
let Test.7 : {} = UnionAtIndex (Id 0) (Index 0) Test.13;
|
||||
let Test.17 : [C {}, C U64] = TagId(0) Test.7;
|
||||
ret Test.17;
|
||||
|
||||
procedure Test.0 ():
|
||||
let Test.11 : Str = "123";
|
||||
let Test.10 : [C {}, C U64] = CallByName Test.3 Test.11;
|
||||
dec Test.11;
|
||||
ret Test.10;
|
|
@ -3695,3 +3695,18 @@ fn dec_refcount_for_usage_after_early_return_in_if() {
|
|||
"#
|
||||
)
|
||||
}
|
||||
|
||||
#[mono_test]
|
||||
fn return_annotated() {
|
||||
indoc!(
|
||||
r#"
|
||||
validateInput : Str -> Result U64 _
|
||||
validateInput = \str ->
|
||||
num = try Str.toU64 str
|
||||
|
||||
Ok num
|
||||
|
||||
validateInput "123"
|
||||
"#
|
||||
)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue