Constrain function annotation fx to body

This commit is contained in:
Agus Zubiaga 2024-10-10 19:49:48 -03:00
parent b9b85a222f
commit 7af05cc6c9
No known key found for this signature in database
10 changed files with 102 additions and 87 deletions

View file

@ -59,21 +59,32 @@ pub struct Env {
pub resolutions_to_make: Vec<OpportunisticResolve>,
pub home: ModuleId,
/// The enclosing function's fx var to be unified with inner calls
pub fn_fx_var: Option<Variable>,
pub enclosing_fx: Option<EnclosingFx>,
}
#[derive(Clone, Copy)]
pub struct EnclosingFx {
pub fx_var: Variable,
pub ann_region: Option<Region>,
}
impl Env {
pub fn with_fx<F, T>(&mut self, fx_var: Variable, f: F) -> T
pub fn with_enclosing_fx<F, T>(
&mut self,
fx_var: Variable,
ann_region: Option<Region>,
f: F,
) -> T
where
F: FnOnce(&mut Env) -> T,
{
let prev_fx_var = self.fn_fx_var;
let prev = self.enclosing_fx.take();
self.fn_fx_var = Some(fx_var);
self.enclosing_fx = Some(EnclosingFx { fx_var, ann_region });
let result = f(self);
self.fn_fx_var = prev_fx_var;
self.enclosing_fx = prev;
result
}
@ -168,7 +179,7 @@ fn constrain_untyped_closure(
loc_body_expr.region,
));
let ret_constraint = env.with_fx(fx_var, |env| {
let ret_constraint = env.with_enclosing_fx(fx_var, None, |env| {
constrain_expr(
types,
constraints,
@ -590,12 +601,35 @@ pub fn constrain_expr(
let category = Category::CallResult(opt_symbol, *called_via);
let fx_expected_type = match env.enclosing_fx {
Some(enclosing_fn) => {
let enclosing_fx_index = constraints.push_variable(enclosing_fn.fx_var);
constraints.push_expected_type(ForReason(
Reason::CallInFunction(enclosing_fn.ann_region),
enclosing_fx_index,
region,
))
}
None => constraints.push_expected_type(ForReason(
Reason::CallInTopLevelDef,
// top-level defs are only allowed to call pure functions
constraints.push_variable(Variable::PURE),
region,
)),
};
let and_cons = [
fn_con,
constraints.equal_types_var(*fn_var, expected_fn_type, category.clone(), fn_region),
constraints.and_constraint(arg_cons),
constraints.equal_types_var(*ret_var, expected_final_type, category, region),
constraints.call_fx(env.fn_fx_var.unwrap_or(Variable::PURE), *fx_var),
constraints.equal_types_var(
*ret_var,
expected_final_type,
category.clone(),
region,
),
constraints.equal_types_var(*fx_var, fx_expected_type, category, region),
];
let and_constraint = constraints.and_constraint(and_cons);
@ -1990,7 +2024,10 @@ fn constrain_function_def(
home: env.home,
rigids: ftv,
resolutions_to_make: vec![],
fn_fx_var: Some(function_def.fx_type),
enclosing_fx: Some(EnclosingFx {
fx_var: function_def.fx_type,
ann_region: Some(annotation.region),
}),
};
let region = loc_function_def.region;
@ -2110,6 +2147,13 @@ fn constrain_function_def(
let defs_constraint = constraints.and_constraint(argument_pattern_state.constraints);
let cons = [
// Store fx type first so errors are reported at call site
constraints.store(
fx_type_index,
function_def.fx_type,
std::file!(),
std::line!(),
),
constraints.let_constraint(
[],
argument_pattern_state.vars,
@ -2128,12 +2172,6 @@ fn constrain_function_def(
std::file!(),
std::line!(),
),
constraints.store(
fx_type_index,
function_def.fx_type,
std::file!(),
std::line!(),
),
// Now, check the solved function type matches the annotation.
constraints.equal_types(
solved_fn_type,
@ -2242,7 +2280,7 @@ fn constrain_destructure_def(
home: env.home,
rigids: ftv,
resolutions_to_make: vec![],
fn_fx_var: env.fn_fx_var,
enclosing_fx: env.enclosing_fx,
};
let signature_index = constraints.push_type(types, signature);
@ -2345,7 +2383,7 @@ fn constrain_value_def(
home: env.home,
rigids: ftv,
resolutions_to_make: vec![],
fn_fx_var: env.fn_fx_var,
enclosing_fx: env.enclosing_fx,
};
let loc_pattern = Loc::at(loc_symbol.region, Pattern::Identifier(loc_symbol.value));
@ -2633,7 +2671,7 @@ pub fn constrain_decls(
home,
rigids: MutMap::default(),
resolutions_to_make: vec![],
fn_fx_var: None,
enclosing_fx: None,
};
debug_assert_eq!(declarations.declarations.len(), declarations.symbols.len());
@ -2865,7 +2903,7 @@ fn constrain_typed_def(
home: env.home,
resolutions_to_make: vec![],
rigids: ftv,
fn_fx_var: env.fn_fx_var,
enclosing_fx: env.enclosing_fx,
};
let signature_index = constraints.push_type(types, signature);
@ -2974,20 +3012,25 @@ fn constrain_typed_def(
ret_type_index,
));
let ret_constraint = constrain_expr(
types,
constraints,
env,
loc_body_expr.region,
&loc_body_expr.value,
body_type,
);
let ret_constraint = env.with_enclosing_fx(fx_var, Some(annotation.region), |env| {
constrain_expr(
types,
constraints,
env,
loc_body_expr.region,
&loc_body_expr.value,
body_type,
)
});
let ret_constraint = attach_resolution_constraints(constraints, env, ret_constraint);
vars.push(*fn_var);
let defs_constraint = constraints.and_constraint(argument_pattern_state.constraints);
let cons = [
// Store fx type first so errors are reported at call site
constraints.store(fx_type_index, fx_var, std::file!(), std::line!()),
constraints.let_constraint(
[],
argument_pattern_state.vars,
@ -3890,7 +3933,7 @@ fn constraint_recursive_function(
constraints.push_type(types, typ)
};
let expr_con = env.with_fx(fx_var, |env| {
let expr_con = env.with_enclosing_fx(fx_var, Some(annotation.region), |env| {
let expected = constraints.push_expected_type(NoExpectation(ret_type_index));
constrain_expr(
types,
@ -4439,19 +4482,21 @@ fn rec_defs_help(
let typ = types.function(pattern_types, lambda_set, ret_type, fx_type);
constraints.push_type(types, typ)
};
let expr_con = env.with_fx(fx_var, |env| {
let body_type =
constraints.push_expected_type(NoExpectation(ret_type_index));
let expr_con =
env.with_enclosing_fx(fx_var, Some(annotation.region), |env| {
let body_type =
constraints.push_expected_type(NoExpectation(ret_type_index));
constrain_expr(
types,
constraints,
env,
loc_body_expr.region,
&loc_body_expr.value,
body_type,
)
});
constrain_expr(
types,
constraints,
env,
loc_body_expr.region,
&loc_body_expr.value,
body_type,
)
});
let expr_con = attach_resolution_constraints(constraints, env, expr_con);
vars.push(*fn_var);
@ -4460,6 +4505,8 @@ fn rec_defs_help(
constraints.and_constraint(argument_pattern_state.constraints);
let expected_index = constraints.push_expected_type(expected);
let cons = [
// Store fx type first so errors are reported at call site
constraints.store(fx_type_index, fx_var, std::file!(), std::line!()),
constraints.let_constraint(
[],
argument_pattern_state.vars,