Ensure stack frame is always popped when procedure returns from inside a branch

This commit is contained in:
Brian Carroll 2021-09-30 15:49:54 +01:00
parent 79ac2f04b8
commit 7ac7e16f60
4 changed files with 49 additions and 24 deletions

View file

@ -13,7 +13,7 @@ use roc_mono::layout::{Builtin, Layout};
use crate::layout::WasmLayout;
use crate::storage::SymbolStorage;
use crate::{
allocate_stack_frame, copy_memory, free_stack_frame, round_up_to_alignment, LocalId, PTR_SIZE,
push_stack_frame, copy_memory, pop_stack_frame, round_up_to_alignment, LocalId, PTR_SIZE,
PTR_TYPE,
};
@ -93,7 +93,7 @@ impl<'a> WasmBackend<'a> {
}
pub fn build_proc(&mut self, proc: Proc<'a>, sym: Symbol) -> Result<u32, String> {
let signature_builder = self.build_signature(&proc);
let signature_builder = self.start_proc(&proc);
self.build_stmt(&proc.body, &proc.ret_layout)?;
@ -106,14 +106,17 @@ impl<'a> WasmBackend<'a> {
Ok(function_index)
}
fn build_signature(&mut self, proc: &Proc<'a>) -> SignatureBuilder {
fn start_proc(&mut self, proc: &Proc<'a>) -> SignatureBuilder {
let ret_layout = WasmLayout::new(&proc.ret_layout);
let signature_builder = if let WasmLayout::StackMemory { .. } = ret_layout {
self.arg_types.push(PTR_TYPE);
self.start_block(BlockType::NoResult); // block to ensure all paths pop stack memory (if any)
builder::signature()
} else {
builder::signature().with_result(ret_layout.value_type())
let ret_type = ret_layout.value_type();
self.start_block(BlockType::Value(ret_type)); // block to ensure all paths pop stack memory (if any)
builder::signature().with_result(ret_type)
};
for (layout, symbol) in proc.args {
@ -124,10 +127,12 @@ impl<'a> WasmBackend<'a> {
}
fn finalize_proc(&mut self, signature_builder: SignatureBuilder) -> FunctionDefinition {
self.end_block(); // end the block from start_proc, to ensure all paths pop stack memory (if any)
let mut final_instructions = Vec::with_capacity(self.instructions.len() + 10);
if self.stack_memory > 0 {
allocate_stack_frame(
push_stack_frame(
&mut final_instructions,
self.stack_memory,
self.stack_frame_pointer.unwrap(),
@ -137,13 +142,13 @@ impl<'a> WasmBackend<'a> {
final_instructions.extend(self.instructions.drain(0..));
if self.stack_memory > 0 {
free_stack_frame(
pop_stack_frame(
&mut final_instructions,
self.stack_memory,
self.stack_frame_pointer.unwrap(),
);
}
final_instructions.push(Instruction::End);
final_instructions.push(End);
builder::function()
.with_signature(signature_builder.build_sig())
@ -275,12 +280,9 @@ impl<'a> WasmBackend<'a> {
self.instructions.push(Loop(BlockType::Value(value_type)));
}
fn start_block(&mut self) {
fn start_block(&mut self, block_type: BlockType) {
self.block_depth += 1;
// Our blocks always end with a `return` or `br`,
// so they never leave extra values on the stack
self.instructions.push(Block(BlockType::NoResult));
self.instructions.push(Block(block_type));
}
fn end_block(&mut self) {
@ -308,7 +310,7 @@ impl<'a> WasmBackend<'a> {
self.symbol_storage_map.insert(*let_sym, storage);
}
self.build_expr(let_sym, expr, layout)?;
self.instructions.push(Return); // TODO: branch instead of return so we can clean up stack
self.instructions.push(Br(self.block_depth)); // jump to end of function (stack frame pop)
Ok(())
}
@ -319,7 +321,12 @@ impl<'a> WasmBackend<'a> {
.local_id();
self.build_expr(sym, expr, layout)?;
self.instructions.push(SetLocal(local_id.0));
// If this local is shared with the stack frame pointer, it's already assigned
match self.stack_frame_pointer {
Some(sfp) if sfp == local_id => {}
_ => self.instructions.push(SetLocal(local_id.0))
}
self.build_stmt(following, ret_layout)?;
Ok(())
@ -351,7 +358,7 @@ impl<'a> WasmBackend<'a> {
| VarPrimitive { local_id, .. }
| VarHeapMemory { local_id, .. } => {
self.instructions.push(GetLocal(local_id.0));
self.instructions.push(Return); // TODO: branch instead of return so we can clean up stack
self.instructions.push(Br(self.block_depth)); // jump to end of function (for stack frame pop)
}
}
@ -371,7 +378,7 @@ impl<'a> WasmBackend<'a> {
// create (number_of_branches - 1) new blocks.
for _ in 0..branches.len() {
self.start_block()
self.start_block(BlockType::NoResult)
}
// the LocalId of the symbol that we match on
@ -422,7 +429,7 @@ impl<'a> WasmBackend<'a> {
jp_parameter_local_ids.push(local_id);
}
self.start_block();
self.start_block(BlockType::NoResult);
self.joinpoint_label_map
.insert(*id, (self.block_depth, jp_parameter_local_ids));