Fix tailcalling

This commit is contained in:
Sam Mohr 2024-10-26 06:48:01 -07:00
parent 6a2ffb2f5a
commit a9cd6ac5fa
No known key found for this signature in database
GPG key ID: EA41D161A3C1BC99
4 changed files with 208 additions and 49 deletions

View file

@ -17,6 +17,7 @@ use crate::expr::ClosureData;
use crate::expr::Declarations;
use crate::expr::Expr::{self, *};
use crate::expr::StructAccessorData;
use crate::expr::TailCall;
use crate::expr::{canonicalize_expr, Output, Recursive};
use crate::pattern::{canonicalize_def_header_pattern, BindingsFromPattern, Pattern};
use crate::procedure::QualifiedReference;
@ -134,31 +135,31 @@ impl Annotation {
#[derive(Debug)]
pub(crate) struct CanDefs {
defs: Vec<Option<Def>>,
dbgs: OrderDependentStatements,
expects: OrderDependentStatements,
expects_fx: OrderDependentStatements,
dbgs: ExpectsOrDbgs,
expects: ExpectsOrDbgs,
expects_fx: ExpectsOrDbgs,
def_ordering: DefOrdering,
aliases: VecMap<Symbol, Alias>,
}
#[derive(Clone, Debug)]
pub struct OrderDependentStatements {
pub expressions: Vec<Expr>,
pub struct ExpectsOrDbgs {
pub conditions: Vec<Expr>,
pub regions: Vec<Region>,
pub preceding_comment: Vec<Region>,
}
impl OrderDependentStatements {
impl ExpectsOrDbgs {
fn with_capacity(capacity: usize) -> Self {
Self {
expressions: Vec::with_capacity(capacity),
conditions: Vec::with_capacity(capacity),
regions: Vec::with_capacity(capacity),
preceding_comment: Vec::with_capacity(capacity),
}
}
fn push(&mut self, loc_can_condition: Loc<Expr>, preceding_comment: Region) {
self.expressions.push(loc_can_condition.value);
self.conditions.push(loc_can_condition.value);
self.regions.push(loc_can_condition.region);
self.preceding_comment.push(preceding_comment);
}
@ -304,8 +305,8 @@ pub enum Declaration {
Declare(Def),
DeclareRec(Vec<Def>, IllegalCycleMark),
Builtin(Def),
Expects(OrderDependentStatements),
ExpectsFx(OrderDependentStatements),
Expects(ExpectsOrDbgs),
ExpectsFx(ExpectsOrDbgs),
/// If we know a cycle is illegal during canonicalization.
/// Otherwise we will try to detect this during solving; see [`IllegalCycleMark`].
InvalidCycle(Vec<CycleEntry>),
@ -1235,9 +1236,9 @@ fn canonicalize_value_defs<'a>(
def_ordering.insert_symbol_references(def_id as u32, &temp_output.references)
}
let mut dbgs = OrderDependentStatements::with_capacity(pending_dbgs.len());
let mut expects = OrderDependentStatements::with_capacity(pending_expects.len());
let mut expects_fx = OrderDependentStatements::with_capacity(pending_expects.len());
let mut dbgs = ExpectsOrDbgs::with_capacity(pending_dbgs.len());
let mut expects = ExpectsOrDbgs::with_capacity(pending_expects.len());
let mut expects_fx = ExpectsOrDbgs::with_capacity(pending_expects.len());
for pending in pending_dbgs {
let (loc_can_condition, can_output) = canonicalize_expr(
@ -1712,7 +1713,7 @@ pub(crate) fn sort_top_level_can_defs(
// because of the ordering of declarations, expects should come first because they are
// independent, but can rely on all other top-level symbols in the module
let it = expects
.expressions
.conditions
.into_iter()
.zip(expects.regions)
.zip(expects.preceding_comment);
@ -1725,7 +1726,7 @@ pub(crate) fn sort_top_level_can_defs(
}
let it = expects_fx
.expressions
.conditions
.into_iter()
.zip(expects_fx.regions)
.zip(expects_fx.preceding_comment);
@ -2100,15 +2101,15 @@ pub(crate) fn sort_can_defs(
}
}
if !dbgs.expressions.is_empty() {
if !dbgs.conditions.is_empty() {
declarations.push(Declaration::Expects(dbgs));
}
if !expects.expressions.is_empty() {
if !expects.conditions.is_empty() {
declarations.push(Declaration::Expects(expects));
}
if !expects_fx.expressions.is_empty() {
if !expects_fx.conditions.is_empty() {
declarations.push(Declaration::ExpectsFx(expects_fx));
}
@ -2606,14 +2607,15 @@ fn canonicalize_pending_body<'a>(
// The closure is self tail recursive iff it tail calls itself (by defined name).
let is_recursive = match can_output.tail_call {
Some(tail_symbol) if tail_symbol == *defined_symbol => {
if closure_data.early_returns.is_empty() {
TailCall::NoneMade => Recursive::NotRecursive,
TailCall::Inconsistent => Recursive::Recursive,
TailCall::CallsTo(tail_symbol) => {
if tail_symbol == *defined_symbol {
Recursive::TailRecursive
} else {
Recursive::Recursive
}
}
_ => Recursive::NotRecursive,
};
closure_data.recursive = is_recursive;
@ -2766,7 +2768,7 @@ pub fn report_unused_imports(
}
}
fn decl_to_let<'a>(decl: Declaration, loc_ret: Loc<Expr>) -> Loc<Expr> {
fn decl_to_let(decl: Declaration, loc_ret: Loc<Expr>) -> Loc<Expr> {
match decl {
Declaration::Declare(def) => {
let region = Region::span_across(&def.loc_pattern.region, &loc_ret.region);
@ -2788,7 +2790,7 @@ fn decl_to_let<'a>(decl: Declaration, loc_ret: Loc<Expr>) -> Loc<Expr> {
Declaration::Expects(expects) => {
let mut loc_ret = loc_ret;
let conditions = expects.expressions.into_iter().rev();
let conditions = expects.conditions.into_iter().rev();
let condition_regions = expects.regions.into_iter().rev();
let expect_regions = expects.preceding_comment.into_iter().rev();

View file

@ -1029,11 +1029,10 @@ pub fn desugar_expr<'a>(
}
Return(return_value, after_return) => {
let desugared_return_value = &*env.arena.alloc(desugar_expr(env, scope, return_value));
let desugared_after_return =
after_return.map(|ar| *env.arena.alloc(desugar_expr(env, scope, ar)));
env.arena.alloc(Loc {
value: Return(desugared_return_value, desugared_after_return),
// Do not desugar after_return since it isn't run anyway
value: Return(desugared_return_value, after_return.clone()),
region: loc_expr.region,
})
}

View file

@ -39,20 +39,64 @@ pub type PendingDerives = VecMap<Symbol, (Type, Vec<Loc<Symbol>>)>;
#[derive(Clone, Default, Debug)]
pub struct Output {
pub references: References,
pub tail_call: Option<Symbol>,
pub tail_call: TailCall,
pub introduced_variables: IntroducedVariables,
pub aliases: VecMap<Symbol, Alias>,
pub non_closures: VecSet<Symbol>,
pub pending_derives: PendingDerives,
}
/// Summary of which function, if any, an expression calls, as tracked in
/// `Output::tail_call`. Two summaries are combined with `TailCall::merge`.
#[derive(Clone, Copy, Default, Debug)]
pub enum TailCall {
/// No call to a named function has been recorded. This is the `Default`
/// value and acts as the identity element of `merge`.
#[default]
NoneMade,
/// A call whose function position is a plain variable with this symbol.
CallsTo(Symbol),
/// Two summaries naming different symbols were merged. Once reached,
/// `merge` keeps returning this variant regardless of the other operand.
Inconsistent,
}
impl TailCall {
    /// Classifies `expr` by the named function it calls, if any.
    ///
    /// Returns [`TailCall::CallsTo`] only when `expr` is an `Expr::Call` whose
    /// function position is a bare `Expr::Var`. Every other expression,
    /// including a call through a non-variable function expression, yields
    /// [`TailCall::NoneMade`].
    pub fn for_expr(expr: &Expr) -> Self {
        if let Expr::Call(call_parts, _, _) = expr {
            // The boxed call payload is a 4-tuple; its second slot holds the
            // located expression that is being called.
            let (_, callee, _, _) = &**call_parts;

            if let Expr::Var(symbol, _) = &callee.value {
                return Self::CallsTo(*symbol);
            }
        }

        Self::NoneMade
    }

    /// Combines two summaries into one.
    ///
    /// * `NoneMade` is neutral: merging with it returns the other operand.
    /// * `Inconsistent` is absorbing: merging with it returns `Inconsistent`.
    /// * Two `CallsTo` values stay `CallsTo` when they name the same symbol
    ///   and become `Inconsistent` otherwise.
    pub fn merge(self, other: Self) -> Self {
        match (self, other) {
            (Self::NoneMade, rhs) => rhs,
            (Self::Inconsistent, _) | (_, Self::Inconsistent) => Self::Inconsistent,
            (Self::CallsTo(ours), Self::NoneMade) => Self::CallsTo(ours),
            (Self::CallsTo(ours), Self::CallsTo(theirs)) if ours == theirs => {
                Self::CallsTo(ours)
            }
            (Self::CallsTo(_), Self::CallsTo(_)) => Self::Inconsistent,
        }
    }
}
impl Output {
pub fn union(&mut self, other: Self) {
self.references.union_mut(&other.references);
if let (None, Some(later)) = (self.tail_call, other.tail_call) {
self.tail_call = Some(later);
}
self.tail_call = self.tail_call.merge(other.tail_call);
self.introduced_variables
.union_owned(other.introduced_variables);
@ -724,7 +768,7 @@ pub fn canonicalize_expr<'a>(
let output = Output {
references,
tail_call: None,
tail_call: TailCall::NoneMade,
..Default::default()
};
@ -792,7 +836,7 @@ pub fn canonicalize_expr<'a>(
let output = Output {
references,
tail_call: None,
tail_call: TailCall::NoneMade,
..Default::default()
};
@ -908,18 +952,13 @@ pub fn canonicalize_expr<'a>(
output.union(fn_expr_output);
// Default: We're not tail-calling a symbol (by name), we're tail-calling a function value.
output.tail_call = None;
output.tail_call = TailCall::NoneMade;
let expr = match fn_expr.value {
Var(symbol, _) => {
output.references.insert_call(symbol);
// we're tail-calling a symbol by name, check if it's the tail-callable symbol
output.tail_call = match &env.tailcallable_symbol {
Some(tc_sym) if *tc_sym == symbol => Some(symbol),
Some(_) | None => None,
};
output.tail_call = TailCall::CallsTo(symbol);
Call(
Box::new((
@ -1045,7 +1084,7 @@ pub fn canonicalize_expr<'a>(
canonicalize_expr(env, var_store, scope, loc_cond.region, &loc_cond.value);
// the condition can never be a tail-call
output.tail_call = None;
output.tail_call = TailCall::NoneMade;
let mut can_branches = Vec::with_capacity(branches.len());
@ -1070,7 +1109,7 @@ pub fn canonicalize_expr<'a>(
// if code gen mistakenly thinks this is a tail call just because its condition
// happened to be one. (The condition gave us our initial output value.)
if branches.is_empty() {
output.tail_call = None;
output.tail_call = TailCall::NoneMade;
}
// Incorporate all three expressions into a combined Output value.
@ -1271,14 +1310,6 @@ pub fn canonicalize_expr<'a>(
ast::Expr::Return(return_expr, after_return) => {
let mut output = Output::default();
let (loc_return_expr, output1) = canonicalize_expr(
env,
var_store,
scope,
return_expr.region,
&return_expr.value,
);
if let Some(after_return) = after_return {
let region_with_return =
Region::span_across(&return_expr.region, &after_return.region);
@ -1288,8 +1319,18 @@ pub fn canonicalize_expr<'a>(
});
}
let (loc_return_expr, output1) = canonicalize_expr(
env,
var_store,
scope,
return_expr.region,
&return_expr.value,
);
output.union(output1);
output.tail_call = TailCall::for_expr(&loc_return_expr.value);
let return_var = var_store.fresh();
scope.early_returns.push((return_var, return_expr.region));

View file

@ -14498,4 +14498,121 @@ All branches in an `if` must have the same type!
(*)b
"###
);
// Snapshot test: a `return` written inside the `if` branch of the plain value
// definition `someVal` (there is no enclosing function) is expected to produce
// the RETURN OUTSIDE OF FUNCTION report captured below.
test_report!(
return_outside_of_function,
indoc!(
r"
someVal =
if 10 > 5 then
x = 5
return x
else
6
someVal + 2
"
),
@r###"
RETURN OUTSIDE OF FUNCTION in /code/proj/Main.roc
This `return` statement doesn't belong to a function:
7 return x
^^^^^^^^
I wouldn't know where to return to if I used it!
"###
);
// Snapshot test: statements that follow a `return` in the same branch of
// `myFunction` are expected to produce the UNREACHABLE CODE report captured
// below, which spans the `return` and the statements after it.
test_report!(
statements_after_return,
indoc!(
r#"
myFunction = \x ->
if x == 2 then
return x
log! "someData"
useX x 123
else
x + 5
myFunction 2
"#
),
@r###"
UNREACHABLE CODE in /code/proj/Main.roc
This code won't run because it follows a `return` statement:
6> return x
7>
8> log! "someData"
9> useX x 123
Hint: you can move the `return` statement below this block to make the
block run.
"###
);
// Snapshot test: a `return` used as the final statement of `myFunction` is
// expected to produce the UNNECESSARY RETURN report captured below, which
// suggests writing a plain expression instead.
test_report!(
return_at_end_of_function,
indoc!(
r#"
myFunction = \x ->
y = Num.toStr x
return y
myFunction 3
"#
),
@r###"
UNNECESSARY RETURN in /code/proj/Main.roc
This `return` statement should be an expression instead:
7 return y
^^^^^^^^
In expression-based languages like Roc, the last expression in a
function is treated like a `return` statement. Even though `return` would
work here, just writing an expression is more elegant.
"###
);
// Snapshot test: an early `return "abc"` in one branch of `myFunction`, whose
// other branch evaluates to `x`, is expected to produce the TYPE MISMATCH
// report captured below, naming the `return` type (`Str`) and the type the
// function otherwise returns (`Num *`).
test_report!(
mismatch_early_return_with_function_output,
indoc!(
r#"
myFunction = \x ->
if x == 5 then
return "abc"
else
x
myFunction 3
"#
),
@r###"
TYPE MISMATCH in /code/proj/Main.roc
This `return` statement doesn't match the return type of its enclosing
function:
5 if x == 5 then
6> return "abc"
7 else
8 x
It is a `return` statement of type:
Str
But I need every `return` statement in that function to return:
Num *
"###
);
}