Properly type constrain all function types

This commit is contained in:
Sam Mohr 2024-11-21 04:09:11 -08:00
parent 20ba4a92de
commit f857872903
No known key found for this signature in database
GPG key ID: EA41D161A3C1BC99
4 changed files with 320 additions and 79 deletions

View file

@ -168,37 +168,38 @@ fn constrain_untyped_closure(
vars.push(closure_var);
vars.push(fn_var);
let body_type = constraints.push_expected_type(ForReason(
let return_type_index = constraints.push_expected_type(ForReason(
Reason::FunctionOutput,
return_type_index,
loc_body_expr.region,
));
let ret_constraint = env.with_fx_expectation(fx_var, None, |env| {
constrain_expr(
let returns_constraint = env.with_fx_expectation(fx_var, None, |env| {
let return_con = constrain_expr(
types,
constraints,
env,
loc_body_expr.region,
&loc_body_expr.value,
body_type,
)
});
let mut early_return_constraints = Vec::with_capacity(early_returns.len());
for (early_return_variable, early_return_region) in early_returns {
let early_return_var = constraints.push_variable(*early_return_variable);
let early_return_con = constraints.equal_types(
early_return_var,
body_type,
Category::Return,
*early_return_region,
return_type_index,
);
early_return_constraints.push(early_return_con);
}
let mut return_constraints = Vec::with_capacity(early_returns.len() + 1);
return_constraints.push(return_con);
let early_returns_constraint = constraints.and_constraint(early_return_constraints);
for (early_return_variable, early_return_region) in early_returns {
let early_return_con = constraints.equal_types_var(
*early_return_variable,
return_type_index,
Category::Return,
*early_return_region,
);
return_constraints.push(early_return_con);
}
constraints.and_constraint(return_constraints)
});
// make sure the captured symbols are sorted!
debug_assert_eq!(captured_symbols.to_vec(), {
@ -231,7 +232,7 @@ fn constrain_untyped_closure(
pattern_state.vars,
pattern_state.headers,
pattern_state_constraints,
ret_constraint,
returns_constraint,
Generalizable(true),
),
constraints.and_constraint(pattern_state.delayed_fx_suffix_constraints),
@ -242,7 +243,6 @@ fn constrain_untyped_closure(
region,
fn_var,
),
early_returns_constraint,
closure_constraint,
constraints.flex_to_pure(fx_var),
];
@ -1423,23 +1423,20 @@ pub fn constrain_expr(
return_var,
} => {
let return_type_index = constraints.push_variable(*return_var);
let expected_return_value = constraints.push_expected_type(ForReason(
Reason::FunctionOutput,
return_type_index,
return_value.region,
));
let return_con = constrain_expr(
constrain_expr(
types,
constraints,
env,
return_value.region,
&return_value.value,
expected_return_value,
);
constraints.exists([*return_var], return_con)
)
}
Tag {
tag_union_var: variant_var,
@ -2062,21 +2059,6 @@ fn constrain_function_def(
ret_type_index,
));
let mut early_return_constraints = Vec::with_capacity(function_def.early_returns.len());
for (early_return_variable, early_return_region) in &function_def.early_returns {
let early_return_var = constraints.push_variable(*early_return_variable);
let early_return_con = constraints.equal_types(
early_return_var,
return_type_annotation_expected,
Category::Return,
*early_return_region,
);
early_return_constraints.push(early_return_con);
}
let early_returns_constraint = constraints.and_constraint(early_return_constraints);
let solved_fn_type = {
// TODO(types-soa) optimize for Variable
let pattern_types = types.from_old_type_slice(
@ -2090,8 +2072,8 @@ fn constrain_function_def(
constraints.push_type(types, fn_type)
};
let ret_constraint = {
let con = constrain_expr(
let returns_constraint = {
let return_con = constrain_expr(
types,
constraints,
env,
@ -2099,7 +2081,33 @@ fn constrain_function_def(
&loc_body_expr.value,
return_type_annotation_expected,
);
attach_resolution_constraints(constraints, env, con)
let mut return_constraints =
Vec::with_capacity(function_def.early_returns.len() + 1);
return_constraints.push(return_con);
for (early_return_variable, early_return_region) in &function_def.early_returns {
let early_return_type_expected =
constraints.push_expected_type(Expected::ForReason(
Reason::FunctionOutput,
ret_type_index,
*early_return_region,
));
vars.push(*early_return_variable);
let early_return_con = constraints.equal_types_var(
*early_return_variable,
early_return_type_expected,
Category::Return,
*early_return_region,
);
return_constraints.push(early_return_con);
}
let returns_constraint = constraints.and_constraint(return_constraints);
attach_resolution_constraints(constraints, env, returns_constraint)
};
vars.push(expr_var);
@ -2114,13 +2122,12 @@ fn constrain_function_def(
std::file!(),
std::line!(),
),
early_returns_constraint,
constraints.let_constraint(
[],
argument_pattern_state.vars,
argument_pattern_state.headers,
defs_constraint,
ret_constraint,
returns_constraint,
// This is a syntactic function, it can be generalized
Generalizable(true),
),
@ -2876,6 +2883,7 @@ fn constrain_typed_def(
function_type: fn_var,
closure_type: closure_var,
return_type: ret_var,
early_returns,
fx_type: fx_var,
captured_symbols,
arguments,
@ -2945,7 +2953,7 @@ fn constrain_typed_def(
constraints.push_type(types, fn_type)
};
let body_type = constraints.push_expected_type(FromAnnotation(
let return_type = constraints.push_expected_type(FromAnnotation(
def.loc_pattern.clone(),
arguments.len(),
AnnotationSource::TypedBody {
@ -2954,18 +2962,35 @@ fn constrain_typed_def(
ret_type_index,
));
let ret_constraint = env.with_fx_expectation(fx_var, Some(annotation.region), |env| {
constrain_expr(
types,
constraints,
env,
loc_body_expr.region,
&loc_body_expr.value,
body_type,
)
});
let returns_constraint =
env.with_fx_expectation(fx_var, Some(annotation.region), |env| {
let return_con = constrain_expr(
types,
constraints,
env,
loc_body_expr.region,
&loc_body_expr.value,
return_type,
);
let ret_constraint = attach_resolution_constraints(constraints, env, ret_constraint);
let mut return_constraints = Vec::with_capacity(early_returns.len() + 1);
return_constraints.push(return_con);
for (early_return_variable, early_return_region) in early_returns {
let early_return_con = constraints.equal_types_var(
*early_return_variable,
return_type,
Category::Return,
*early_return_region,
);
return_constraints.push(early_return_con);
}
let returns_constraint = constraints.and_constraint(return_constraints);
attach_resolution_constraints(constraints, env, returns_constraint)
});
vars.push(*fn_var);
let defs_constraint = constraints.and_constraint(argument_pattern_state.constraints);
@ -2978,7 +3003,7 @@ fn constrain_typed_def(
argument_pattern_state.vars,
argument_pattern_state.headers,
defs_constraint,
ret_constraint,
returns_constraint,
// This is a syntactic function, it can be generalized
Generalizable(true),
),
@ -3985,18 +4010,38 @@ fn constraint_recursive_function(
constraints.push_type(types, typ)
};
let expr_con = env.with_fx_expectation(fx_var, Some(annotation.region), |env| {
let expected = constraints.push_expected_type(NoExpectation(ret_type_index));
constrain_expr(
types,
constraints,
env,
loc_body_expr.region,
&loc_body_expr.value,
expected,
)
});
let expr_con = attach_resolution_constraints(constraints, env, expr_con);
let returns_constraint =
env.with_fx_expectation(fx_var, Some(annotation.region), |env| {
let expected = constraints.push_expected_type(NoExpectation(ret_type_index));
let return_con = constrain_expr(
types,
constraints,
env,
loc_body_expr.region,
&loc_body_expr.value,
expected,
);
let mut return_constraints =
Vec::with_capacity(function_def.early_returns.len() + 1);
return_constraints.push(return_con);
for (early_return_variable, early_return_region) in &function_def.early_returns
{
let early_return_con = constraints.equal_types_var(
*early_return_variable,
expected,
Category::Return,
*early_return_region,
);
return_constraints.push(early_return_con);
}
let returns_constraint = constraints.and_constraint(return_constraints);
attach_resolution_constraints(constraints, env, returns_constraint)
});
vars.push(expr_var);
@ -4008,7 +4053,7 @@ fn constraint_recursive_function(
argument_pattern_state.vars,
argument_pattern_state.headers,
state_constraints,
expr_con,
returns_constraint,
// Syntactic function can be generalized
Generalizable(true),
),
@ -4470,6 +4515,7 @@ fn rec_defs_help(
function_type: fn_var,
closure_type: closure_var,
return_type: ret_var,
early_returns,
fx_type: fx_var,
captured_symbols,
arguments,
@ -4537,22 +4583,40 @@ fn rec_defs_help(
let typ = types.function(pattern_types, lambda_set, ret_type, fx_type);
constraints.push_type(types, typ)
};
let expr_con =
let returns_constraint =
env.with_fx_expectation(fx_var, Some(annotation.region), |env| {
let body_type =
let return_type_expected =
constraints.push_expected_type(NoExpectation(ret_type_index));
constrain_expr(
let return_con = constrain_expr(
types,
constraints,
env,
loc_body_expr.region,
&loc_body_expr.value,
body_type,
)
});
return_type_expected,
);
let expr_con = attach_resolution_constraints(constraints, env, expr_con);
let mut return_constraints =
Vec::with_capacity(early_returns.len() + 1);
return_constraints.push(return_con);
for (early_return_variable, early_return_region) in early_returns {
let early_return_con = constraints.equal_types_var(
*early_return_variable,
return_type_expected,
Category::Return,
*early_return_region,
);
return_constraints.push(early_return_con);
}
let returns_constraint =
constraints.and_constraint(return_constraints);
attach_resolution_constraints(constraints, env, returns_constraint)
});
vars.push(*fn_var);
@ -4567,7 +4631,7 @@ fn rec_defs_help(
argument_pattern_state.vars,
argument_pattern_state.headers,
state_constraints,
expr_con,
returns_constraint,
generalizable,
),
// Check argument suffixes against usage

View file

@ -108,6 +108,24 @@ fn early_return_solo() {
);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn early_return_solo_annotated() {
assert_evals_to!(
r#"
identity : Str -> Str
identity = \x ->
return x
identity "abc"
"#,
RocStr::from("abc"),
RocStr,
identity,
true
);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))]
fn early_return_annotated_function() {
@ -149,3 +167,99 @@ fn early_return_annotated_function() {
RocList<RocStr>
);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))]
fn early_return_nested_annotated_function() {
assert_evals_to!(
indoc!(
r#"
app "test" provides [main] to "./platform"
validateInput : Str -> Result U64 [InvalidNumStr, LessThanFive]
validateInput = \str ->
failIfLessThanFive : U64 -> Result {} [LessThanFive]
failIfLessThanFive = \n ->
if n < 5 then
Err LessThanFive
else
Ok {}
num = try Str.toU64 str
when failIfLessThanFive num is
Err err ->
return Err err
Ok {} ->
Ok num
main : List Str
main =
["abc", "3", "7"]
|> List.map validateInput
|> List.map Inspect.toStr
"#
),
RocList::from_slice(&[
RocStr::from("(Err InvalidNumStr)"),
RocStr::from("(Err LessThanFive)"),
RocStr::from("(Ok 7)")
]),
RocList<RocStr>
);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm", feature = "gen-dev"))]
fn early_return_annotated_recursive_function() {
assert_evals_to!(
indoc!(
r#"
app "test" provides [main] to "./platform"
mightCallSecond : U64 -> Result U64 _
mightCallSecond = \num ->
nextNum =
if num < 5 then
return Err LessThanFive
else
num - 1
mightCallFirst nextNum
mightCallFirst : U64 -> Result U64 _
mightCallFirst = \num ->
nextNum =
if num < 10 then
return Err LessThanTen
else
num * 2
if nextNum > 25 then
Ok nextNum
else
mightCallSecond nextNum
main : List Str
main =
[
mightCallSecond 3,
mightCallSecond 7,
mightCallSecond 20,
mightCallFirst 7,
mightCallFirst 15,
]
|> List.map Inspect.toStr
"#
),
RocList::from_slice(&[
RocStr::from("(Err LessThanFive)"),
RocStr::from("(Err LessThanTen)"),
RocStr::from("(Ok 38)"),
RocStr::from("(Err LessThanTen)"),
RocStr::from("(Ok 30)")
]),
RocList<RocStr>
);
}

View file

@ -0,0 +1,48 @@
procedure Bool.11 (#Attr.2, #Attr.3):
let Bool.23 : Int1 = lowlevel Eq #Attr.2 #Attr.3;
ret Bool.23;
procedure Str.26 (Str.83):
let Str.246 : [C {}, C U64] = CallByName Str.66 Str.83;
ret Str.246;
procedure Str.42 (#Attr.2):
let Str.254 : {U64, U8} = lowlevel StrToNum #Attr.2;
ret Str.254;
procedure Str.66 (Str.191):
let Str.192 : {U64, U8} = CallByName Str.42 Str.191;
let Str.252 : U8 = StructAtIndex 1 Str.192;
let Str.253 : U8 = 0i64;
let Str.249 : Int1 = CallByName Bool.11 Str.252 Str.253;
if Str.249 then
let Str.251 : U64 = StructAtIndex 0 Str.192;
let Str.250 : [C {}, C U64] = TagId(1) Str.251;
ret Str.250;
else
let Str.248 : {} = Struct {};
let Str.247 : [C {}, C U64] = TagId(0) Str.248;
ret Str.247;
procedure Test.3 (Test.4):
joinpoint Test.14 Test.5:
let Test.12 : [C {}, C U64] = TagId(1) Test.5;
ret Test.12;
in
let Test.13 : [C {}, C U64] = CallByName Str.26 Test.4;
let Test.18 : U8 = 1i64;
let Test.19 : U8 = GetTagId Test.13;
let Test.20 : Int1 = lowlevel Eq Test.18 Test.19;
if Test.20 then
let Test.6 : U64 = UnionAtIndex (Id 1) (Index 0) Test.13;
jump Test.14 Test.6;
else
let Test.7 : {} = UnionAtIndex (Id 0) (Index 0) Test.13;
let Test.17 : [C {}, C U64] = TagId(0) Test.7;
ret Test.17;
procedure Test.0 ():
let Test.11 : Str = "123";
let Test.10 : [C {}, C U64] = CallByName Test.3 Test.11;
dec Test.11;
ret Test.10;

View file

@ -3695,3 +3695,18 @@ fn dec_refcount_for_usage_after_early_return_in_if() {
"#
)
}
#[mono_test]
fn return_annotated() {
indoc!(
r#"
validateInput : Str -> Result U64 _
validateInput = \str ->
num = try Str.toU64 str
Ok num
validateInput "123"
"#
)
}