compiler: improve return elimination

This commit is contained in:
Alex Badics 2025-02-14 19:00:49 +01:00 committed by Olivier Goffart
parent c2e9fa9f66
commit 38ddb4d2b7

View file

@ -26,19 +26,25 @@ pub fn remove_return(doc: &crate::object_tree::Document) {
visit(e, &mut ret_ty);
let Some(ret_ty) = ret_ty else { return };
let ctx = RemoveReturnContext { ret_ty };
*e = process_expression(std::mem::take(e), &ctx).to_expression(&ctx.ret_ty);
*e = process_expression(std::mem::take(e), true, &ctx).to_expression(&ctx.ret_ty);
})
});
}
fn process_expression(e: Expression, ctx: &RemoveReturnContext) -> ExpressionResult {
fn process_expression(
e: Expression,
toplevel: bool,
ctx: &RemoveReturnContext,
) -> ExpressionResult {
let ty = e.ty();
match e {
Expression::ReturnStatement(expr) => ExpressionResult::Return(expr.map(|e| *e)),
Expression::CodeBlock(expr) => process_codeblock(expr.into_iter().peekable(), &ty, ctx),
Expression::CodeBlock(expr) => {
process_codeblock(expr.into_iter().peekable(), toplevel, &ty, ctx)
}
Expression::Condition { condition, true_expr, false_expr } => {
let te = process_expression(*true_expr, ctx);
let fe = process_expression(*false_expr, ctx);
let te = process_expression(*true_expr, false, ctx);
let fe = process_expression(*false_expr, false, ctx);
match (te, fe) {
(ExpressionResult::Just(te), ExpressionResult::Just(fe)) => {
Expression::Condition { condition, true_expr: te.into(), false_expr: fe.into() }
@ -60,6 +66,13 @@ fn process_expression(e: Expression, ctx: &RemoveReturnContext) -> ExpressionRes
actual_value: cleanup_empty_block(fe),
}
}
(ExpressionResult::Return(te), ExpressionResult::Return(fe)) => {
ExpressionResult::Return(Some(Expression::Condition {
condition: condition.into(),
true_expr: te.unwrap_or(Expression::CodeBlock(vec![])).into(),
false_expr: fe.unwrap_or(Expression::CodeBlock(vec![])).into(),
}))
}
(te, fe) => {
let te = te.into_return_object(&ty, &ctx.ret_ty);
let fe = fe.into_return_object(&ty, &ctx.ret_ty);
@ -75,9 +88,8 @@ fn process_expression(e: Expression, ctx: &RemoveReturnContext) -> ExpressionRes
}
}
}
Expression::Cast { from, to } => {
process_expression(*from, ctx).map_value(|e| Expression::Cast { from: e.into(), to })
}
Expression::Cast { from, to } => process_expression(*from, toplevel, ctx)
.map_value(|e| Expression::Cast { from: e.into(), to }),
e => {
// Normally there shouldn't be any 'return' statements in there since return are not allowed in arbitrary expressions
ExpressionResult::Just(e)
@ -96,12 +108,13 @@ fn cleanup_empty_block(te: Expression) -> Option<Expression> {
fn process_codeblock(
mut iter: std::iter::Peekable<impl Iterator<Item = Expression>>,
toplevel: bool,
ty: &Type,
ctx: &RemoveReturnContext,
) -> ExpressionResult {
let mut stmts = vec![];
while let Some(e) = iter.next() {
match process_expression(e, ctx) {
match process_expression(e, toplevel, ctx) {
ExpressionResult::Just(x) => stmts.push(x),
ExpressionResult::Return(x) => {
stmts.extend(x);
@ -124,6 +137,22 @@ fn process_codeblock(
actual_value,
};
};
if toplevel {
let rest = process_codeblock(iter, true, ty, ctx).to_expression(&ctx.ret_ty);
let mut rest_ex = Expression::CodeBlock(
actual_value.into_iter().chain(core::iter::once(rest)).collect(),
);
if rest_ex.ty() != ctx.ret_ty {
rest_ex =
Expression::Cast { from: Box::new(rest_ex), to: ctx.ret_ty.clone() }
}
return ExpressionResult::MaybeReturn {
pre_statements: stmts,
condition,
returned_value,
actual_value: Some(rest_ex),
};
}
return continue_codeblock(
iter,
ty,
@ -172,7 +201,7 @@ fn continue_codeblock(
has_return_value: bool,
has_value: bool,
) -> ExpressionResult {
let rest = process_codeblock(iter, ty, ctx).into_return_object(ty, &ctx.ret_ty);
let rest = process_codeblock(iter, false, ty, ctx).into_return_object(ty, &ctx.ret_ty);
static COUNT: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
let unique_name = format_smolstr!(
"return_check_merge{}",