Ensure stack frame is always popped when procedure returns from inside a branch

This commit is contained in:
Brian Carroll 2021-09-30 15:49:54 +01:00
parent 79ac2f04b8
commit 7ac7e16f60
4 changed files with 49 additions and 24 deletions

View file

@ -13,7 +13,7 @@ use roc_mono::layout::{Builtin, Layout};
use crate::layout::WasmLayout; use crate::layout::WasmLayout;
use crate::storage::SymbolStorage; use crate::storage::SymbolStorage;
use crate::{ use crate::{
allocate_stack_frame, copy_memory, free_stack_frame, round_up_to_alignment, LocalId, PTR_SIZE, push_stack_frame, copy_memory, pop_stack_frame, round_up_to_alignment, LocalId, PTR_SIZE,
PTR_TYPE, PTR_TYPE,
}; };
@ -93,7 +93,7 @@ impl<'a> WasmBackend<'a> {
} }
pub fn build_proc(&mut self, proc: Proc<'a>, sym: Symbol) -> Result<u32, String> { pub fn build_proc(&mut self, proc: Proc<'a>, sym: Symbol) -> Result<u32, String> {
let signature_builder = self.build_signature(&proc); let signature_builder = self.start_proc(&proc);
self.build_stmt(&proc.body, &proc.ret_layout)?; self.build_stmt(&proc.body, &proc.ret_layout)?;
@ -106,14 +106,17 @@ impl<'a> WasmBackend<'a> {
Ok(function_index) Ok(function_index)
} }
fn build_signature(&mut self, proc: &Proc<'a>) -> SignatureBuilder { fn start_proc(&mut self, proc: &Proc<'a>) -> SignatureBuilder {
let ret_layout = WasmLayout::new(&proc.ret_layout); let ret_layout = WasmLayout::new(&proc.ret_layout);
let signature_builder = if let WasmLayout::StackMemory { .. } = ret_layout { let signature_builder = if let WasmLayout::StackMemory { .. } = ret_layout {
self.arg_types.push(PTR_TYPE); self.arg_types.push(PTR_TYPE);
self.start_block(BlockType::NoResult); // block to ensure all paths pop stack memory (if any)
builder::signature() builder::signature()
} else { } else {
builder::signature().with_result(ret_layout.value_type()) let ret_type = ret_layout.value_type();
self.start_block(BlockType::Value(ret_type)); // block to ensure all paths pop stack memory (if any)
builder::signature().with_result(ret_type)
}; };
for (layout, symbol) in proc.args { for (layout, symbol) in proc.args {
@ -124,10 +127,12 @@ impl<'a> WasmBackend<'a> {
} }
fn finalize_proc(&mut self, signature_builder: SignatureBuilder) -> FunctionDefinition { fn finalize_proc(&mut self, signature_builder: SignatureBuilder) -> FunctionDefinition {
self.end_block(); // end the block from start_proc, to ensure all paths pop stack memory (if any)
let mut final_instructions = Vec::with_capacity(self.instructions.len() + 10); let mut final_instructions = Vec::with_capacity(self.instructions.len() + 10);
if self.stack_memory > 0 { if self.stack_memory > 0 {
allocate_stack_frame( push_stack_frame(
&mut final_instructions, &mut final_instructions,
self.stack_memory, self.stack_memory,
self.stack_frame_pointer.unwrap(), self.stack_frame_pointer.unwrap(),
@ -137,13 +142,13 @@ impl<'a> WasmBackend<'a> {
final_instructions.extend(self.instructions.drain(0..)); final_instructions.extend(self.instructions.drain(0..));
if self.stack_memory > 0 { if self.stack_memory > 0 {
free_stack_frame( pop_stack_frame(
&mut final_instructions, &mut final_instructions,
self.stack_memory, self.stack_memory,
self.stack_frame_pointer.unwrap(), self.stack_frame_pointer.unwrap(),
); );
} }
final_instructions.push(Instruction::End); final_instructions.push(End);
builder::function() builder::function()
.with_signature(signature_builder.build_sig()) .with_signature(signature_builder.build_sig())
@ -275,12 +280,9 @@ impl<'a> WasmBackend<'a> {
self.instructions.push(Loop(BlockType::Value(value_type))); self.instructions.push(Loop(BlockType::Value(value_type)));
} }
fn start_block(&mut self) { fn start_block(&mut self, block_type: BlockType) {
self.block_depth += 1; self.block_depth += 1;
self.instructions.push(Block(block_type));
// Our blocks always end with a `return` or `br`,
// so they never leave extra values on the stack
self.instructions.push(Block(BlockType::NoResult));
} }
fn end_block(&mut self) { fn end_block(&mut self) {
@ -308,7 +310,7 @@ impl<'a> WasmBackend<'a> {
self.symbol_storage_map.insert(*let_sym, storage); self.symbol_storage_map.insert(*let_sym, storage);
} }
self.build_expr(let_sym, expr, layout)?; self.build_expr(let_sym, expr, layout)?;
self.instructions.push(Return); // TODO: branch instead of return so we can clean up stack self.instructions.push(Br(self.block_depth)); // jump to end of function (stack frame pop)
Ok(()) Ok(())
} }
@ -319,7 +321,12 @@ impl<'a> WasmBackend<'a> {
.local_id(); .local_id();
self.build_expr(sym, expr, layout)?; self.build_expr(sym, expr, layout)?;
self.instructions.push(SetLocal(local_id.0));
// If this local is shared with the stack frame pointer, it's already assigned
match self.stack_frame_pointer {
Some(sfp) if sfp == local_id => {}
_ => self.instructions.push(SetLocal(local_id.0))
}
self.build_stmt(following, ret_layout)?; self.build_stmt(following, ret_layout)?;
Ok(()) Ok(())
@ -351,7 +358,7 @@ impl<'a> WasmBackend<'a> {
| VarPrimitive { local_id, .. } | VarPrimitive { local_id, .. }
| VarHeapMemory { local_id, .. } => { | VarHeapMemory { local_id, .. } => {
self.instructions.push(GetLocal(local_id.0)); self.instructions.push(GetLocal(local_id.0));
self.instructions.push(Return); // TODO: branch instead of return so we can clean up stack self.instructions.push(Br(self.block_depth)); // jump to end of function (for stack frame pop)
} }
} }
@ -371,7 +378,7 @@ impl<'a> WasmBackend<'a> {
// create (number_of_branches - 1) new blocks. // create (number_of_branches - 1) new blocks.
for _ in 0..branches.len() { for _ in 0..branches.len() {
self.start_block() self.start_block(BlockType::NoResult)
} }
// the LocalId of the symbol that we match on // the LocalId of the symbol that we match on
@ -422,7 +429,7 @@ impl<'a> WasmBackend<'a> {
jp_parameter_local_ids.push(local_id); jp_parameter_local_ids.push(local_id);
} }
self.start_block(); self.start_block(BlockType::NoResult);
self.joinpoint_label_map self.joinpoint_label_map
.insert(*id, (self.block_depth, jp_parameter_local_ids)); .insert(*id, (self.block_depth, jp_parameter_local_ids));

View file

@ -26,7 +26,7 @@ pub const ALIGN_8: u32 = 3;
pub const STACK_POINTER_GLOBAL_ID: u32 = 0; pub const STACK_POINTER_GLOBAL_ID: u32 = 0;
pub const STACK_ALIGNMENT_BYTES: i32 = 16; pub const STACK_ALIGNMENT_BYTES: i32 = 16;
#[derive(Clone, Copy, Debug)] #[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct LocalId(pub u32); pub struct LocalId(pub u32);
pub struct Env<'a> { pub struct Env<'a> {
@ -163,7 +163,7 @@ pub fn round_up_to_alignment(unaligned: i32, alignment_bytes: i32) -> i32 {
aligned aligned
} }
pub fn allocate_stack_frame( pub fn push_stack_frame(
instructions: &mut Vec<Instruction>, instructions: &mut Vec<Instruction>,
size: i32, size: i32,
local_frame_pointer: LocalId, local_frame_pointer: LocalId,
@ -178,7 +178,7 @@ pub fn allocate_stack_frame(
]); ]);
} }
pub fn free_stack_frame( pub fn pop_stack_frame(
instructions: &mut Vec<Instruction>, instructions: &mut Vec<Instruction>,
size: i32, size: i32,
local_frame_pointer: LocalId, local_frame_pointer: LocalId,

View file

@ -46,7 +46,7 @@ macro_rules! build_wrapper_body_primitive {
fn build_wrapper_body(main_function_index: u32) -> Vec<Instruction> { fn build_wrapper_body(main_function_index: u32) -> Vec<Instruction> {
let size: i32 = 8; let size: i32 = 8;
let mut instructions = Vec::with_capacity(16); let mut instructions = Vec::with_capacity(16);
allocate_stack_frame(&mut instructions, size, LocalId(STACK_POINTER_LOCAL_ID)); push_stack_frame(&mut instructions, size, LocalId(STACK_POINTER_LOCAL_ID));
instructions.extend([ instructions.extend([
// load result address to prepare for the store instruction later // load result address to prepare for the store instruction later
GetLocal(STACK_POINTER_LOCAL_ID), GetLocal(STACK_POINTER_LOCAL_ID),
@ -60,7 +60,7 @@ macro_rules! build_wrapper_body_primitive {
// Return the result pointer // Return the result pointer
GetLocal(STACK_POINTER_LOCAL_ID), GetLocal(STACK_POINTER_LOCAL_ID),
]); ]);
free_stack_frame(&mut instructions, size, LocalId(STACK_POINTER_LOCAL_ID)); pop_stack_frame(&mut instructions, size, LocalId(STACK_POINTER_LOCAL_ID));
instructions.push(End); instructions.push(End);
instructions instructions
} }
@ -77,7 +77,7 @@ macro_rules! wasm_test_result_primitive {
fn build_wrapper_body_stack_memory(main_function_index: u32, size: usize) -> Vec<Instruction> { fn build_wrapper_body_stack_memory(main_function_index: u32, size: usize) -> Vec<Instruction> {
let mut instructions = Vec::with_capacity(16); let mut instructions = Vec::with_capacity(16);
allocate_stack_frame( push_stack_frame(
&mut instructions, &mut instructions,
size as i32, size as i32,
LocalId(STACK_POINTER_LOCAL_ID), LocalId(STACK_POINTER_LOCAL_ID),
@ -92,7 +92,7 @@ fn build_wrapper_body_stack_memory(main_function_index: u32, size: usize) -> Vec
// Return the result address // Return the result address
GetLocal(STACK_POINTER_LOCAL_ID), GetLocal(STACK_POINTER_LOCAL_ID),
]); ]);
free_stack_frame( pop_stack_frame(
&mut instructions, &mut instructions,
size as i32, size as i32,
LocalId(STACK_POINTER_LOCAL_ID), LocalId(STACK_POINTER_LOCAL_ID),

View file

@ -873,6 +873,24 @@ mod wasm_records {
// ); // );
// } // }
#[test]
fn stack_memory_return_from_branch() {
// stack memory pointer should end up in the right place after returning from a branch
assert_evals_to!(
indoc!(
r#"
stackMemoryJunk = { x: 999, y: 111 }
if True then
{ x: 123, y: 321 }
else
stackMemoryJunk
"#
),
(123, 321),
(i64, i64)
);
}
// #[test] // #[test]
// fn blue_and_present() { // fn blue_and_present() {
// assert_evals_to!( // assert_evals_to!(