basic join point

This commit is contained in:
Folkert 2021-09-08 15:54:00 +02:00
parent 1da32f18e5
commit 4e5b677426
2 changed files with 116 additions and 8 deletions

View file

@ -7,7 +7,7 @@ use parity_wasm::elements::{
use roc_collections::all::MutMap;
use roc_module::low_level::LowLevel;
use roc_module::symbol::Symbol;
use roc_mono::ir::{CallType, Expr, Literal, Proc, Stmt};
use roc_mono::ir::{CallType, Expr, JoinPointId, Literal, Proc, Stmt};
use roc_mono::layout::{Builtin, Layout};
// Don't allocate any constant data at address zero or near it. Would be valid, but bug-prone.
@ -69,7 +69,9 @@ pub struct WasmBackend<'a> {
// Functions: internal state & IR mappings
stack_memory: u32,
symbol_storage_map: MutMap<Symbol, SymbolStorage>,
// joinpoint_label_map: MutMap<JoinPointId, LabelId>,
/// how many blocks deep are we (used for jumps)
block_depth: u32,
joinpoint_label_map: MutMap<JoinPointId, (u32, std::vec::Vec<LocalId>)>,
}
impl<'a> WasmBackend<'a> {
@ -92,7 +94,8 @@ impl<'a> WasmBackend<'a> {
// Functions: internal state & IR mappings
stack_memory: 0,
symbol_storage_map: MutMap::default(),
// joinpoint_label_map: MutMap::default(),
block_depth: 0,
joinpoint_label_map: MutMap::default(),
}
}
@ -182,6 +185,33 @@ impl<'a> WasmBackend<'a> {
Ok(())
}
fn start_loop(&mut self) {
self.block_depth += 1;
// self.instructions.push(Loop(BlockType::NoResult));
self.instructions
.push(Loop(BlockType::Value(ValueType::I64)));
}
fn end_loop(&mut self) {
self.block_depth -= 1;
self.instructions.push(End);
}
fn start_block(&mut self) {
self.block_depth += 1;
// Our blocks always end with a `return` or `br`,
// so they never leave extra values on the stack
self.instructions.push(Block(BlockType::NoResult));
}
fn end_block(&mut self) {
self.block_depth -= 1;
self.instructions.push(End);
}
fn build_stmt(&mut self, stmt: &Stmt<'a>, ret_layout: &Layout<'a>) -> Result<(), String> {
match stmt {
// This pattern is a simple optimisation to get rid of one local and two instructions per proc.
@ -228,11 +258,8 @@ impl<'a> WasmBackend<'a> {
// or `BrTable`
// create (number_of_branches - 1) new blocks.
//
// Every branch ends in a `return`,
// so the block leaves no values on the stack
for _ in 0..branches.len() {
self.instructions.push(Block(BlockType::NoResult));
self.start_block()
}
// the LocalId of the symbol that we match on
@ -262,13 +289,64 @@ impl<'a> WasmBackend<'a> {
// (the first branch would have broken out of 1 block,
// hence we must generate its code first)
for (_, _, branch) in branches.iter() {
self.instructions.push(End);
self.end_block();
self.build_stmt(branch, ret_layout)?;
}
Ok(())
}
Stmt::Join {
id,
parameters,
body,
remainder,
} => {
// make locals for join pointer parameters
let mut local_ids = std::vec::Vec::with_capacity(parameters.len());
for parameter in parameters.iter() {
let wasm_layout = WasmLayout::new(&parameter.layout)?;
let local_id = self.insert_local(wasm_layout, parameter.symbol);
local_ids.push(local_id);
}
self.start_block();
self.joinpoint_label_map
.insert(*id, (self.block_depth, local_ids.clone()));
self.build_stmt(remainder, ret_layout)?;
self.end_block();
self.start_loop();
self.build_stmt(body, ret_layout)?;
self.end_loop();
Ok(())
}
Stmt::Jump(id, arguments) => {
let (target, locals) = &self.joinpoint_label_map[id];
// put the arguments on the stack
for (symbol, local_id) in arguments.iter().zip(locals.iter()) {
let argument = match self.symbol_storage_map.get(symbol) {
Some(SymbolStorage(local_id, _)) => local_id.0,
None => unreachable!("symbol not defined: {:?}", symbol),
};
self.instructions.push(GetLocal(argument));
self.instructions.push(SetLocal(local_id.0));
}
// jump
let levels = self.block_depth - target;
self.instructions.push(Br(levels));
Ok(())
}
x => Err(format!("statement not yet implemented: {:?}", x)),
}
}

View file

@ -116,6 +116,36 @@ mod dev_num {
);
}
#[test]
fn join_point() {
assert_evals_to!(
indoc!(
r#"
x = if True then 111 else 222
x + 123
"#
),
234,
i64
);
}
#[test]
#[ignore]
fn factorial() {
assert_evals_to!(
indoc!(
r#"
fac = \n ->
if n
"#
),
234,
i64
);
}
// #[test]
// fn gen_add_f64() {
// assert_evals_to!(