Merge pull request #1682 from rtfeldman/gen_wasm_join

joinpoints for the wasm backend
This commit is contained in:
Folkert de Vries 2021-09-10 16:23:34 +02:00 committed by GitHub
commit 575aaa1f0b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 209 additions and 117 deletions

View file

@ -7,7 +7,7 @@ use parity_wasm::elements::{
use roc_collections::all::MutMap; use roc_collections::all::MutMap;
use roc_module::low_level::LowLevel; use roc_module::low_level::LowLevel;
use roc_module::symbol::Symbol; use roc_module::symbol::Symbol;
use roc_mono::ir::{CallType, Expr, Literal, Proc, Stmt}; use roc_mono::ir::{CallType, Expr, JoinPointId, Literal, Proc, Stmt};
use roc_mono::layout::{Builtin, Layout}; use roc_mono::layout::{Builtin, Layout};
// Don't allocate any constant data at address zero or near it. Would be valid, but bug-prone. // Don't allocate any constant data at address zero or near it. Would be valid, but bug-prone.
@ -23,7 +23,7 @@ struct LabelId(u32);
#[derive(Debug)] #[derive(Debug)]
struct SymbolStorage(LocalId, WasmLayout); struct SymbolStorage(LocalId, WasmLayout);
#[derive(Debug)] #[derive(Clone, Copy, Debug)]
struct WasmLayout { struct WasmLayout {
value_type: ValueType, value_type: ValueType,
stack_memory: u32, stack_memory: u32,
@ -69,7 +69,9 @@ pub struct WasmBackend<'a> {
// Functions: internal state & IR mappings // Functions: internal state & IR mappings
stack_memory: u32, stack_memory: u32,
symbol_storage_map: MutMap<Symbol, SymbolStorage>, symbol_storage_map: MutMap<Symbol, SymbolStorage>,
// joinpoint_label_map: MutMap<JoinPointId, LabelId>, /// how many blocks deep are we (used for jumps)
block_depth: u32,
joinpoint_label_map: MutMap<JoinPointId, (u32, std::vec::Vec<LocalId>)>,
} }
impl<'a> WasmBackend<'a> { impl<'a> WasmBackend<'a> {
@ -92,7 +94,8 @@ impl<'a> WasmBackend<'a> {
// Functions: internal state & IR mappings // Functions: internal state & IR mappings
stack_memory: 0, stack_memory: 0,
symbol_storage_map: MutMap::default(), symbol_storage_map: MutMap::default(),
// joinpoint_label_map: MutMap::default(), block_depth: 0,
joinpoint_label_map: MutMap::default(),
} }
} }
@ -182,6 +185,27 @@ impl<'a> WasmBackend<'a> {
Ok(()) Ok(())
} }
/// Opens a wasm `loop` whose block type is `value_type`, i.e. a loop that leaves
/// one value of that type on the stack.
///
/// Also increments `block_depth`, so the loop counts as one nesting level when
/// jump distances are computed. The loop is closed again with `end_block`.
fn start_loop_with_return(&mut self, value_type: ValueType) {
self.block_depth += 1;
// The loop is typed with `value_type` rather than `BlockType::NoResult`.
self.instructions.push(Loop(BlockType::Value(value_type)));
}
/// Opens a wasm `block` with no result type and increments `block_depth`
/// to record the extra nesting level.
fn start_block(&mut self) {
self.block_depth += 1;
// Our blocks always end with a `return` or `br`,
// so they never leave extra values on the stack
self.instructions.push(Block(BlockType::NoResult));
}
/// Closes the innermost open block or loop: decrements `block_depth` and emits `End`.
///
/// NOTE(review): `block_depth` is a `u32`, so calling this without a matching
/// `start_block` / `start_loop_with_return` would underflow; presumably callers
/// always pair them. Confirm.
fn end_block(&mut self) {
self.block_depth -= 1;
self.instructions.push(End);
}
fn build_stmt(&mut self, stmt: &Stmt<'a>, ret_layout: &Layout<'a>) -> Result<(), String> { fn build_stmt(&mut self, stmt: &Stmt<'a>, ret_layout: &Layout<'a>) -> Result<(), String> {
match stmt { match stmt {
// This pattern is a simple optimisation to get rid of one local and two instructions per proc. // This pattern is a simple optimisation to get rid of one local and two instructions per proc.
@ -228,11 +252,8 @@ impl<'a> WasmBackend<'a> {
// or `BrTable` // or `BrTable`
// create (number_of_branches - 1) new blocks. // create (number_of_branches - 1) new blocks.
//
// Every branch ends in a `return`,
// so the block leaves no values on the stack
for _ in 0..branches.len() { for _ in 0..branches.len() {
self.instructions.push(Block(BlockType::NoResult)); self.start_block()
} }
// the LocalId of the symbol that we match on // the LocalId of the symbol that we match on
@ -262,13 +283,69 @@ impl<'a> WasmBackend<'a> {
// (the first branch would have broken out of 1 block, // (the first branch would have broken out of 1 block,
// hence we must generate its code first) // hence we must generate its code first)
for (_, _, branch) in branches.iter() { for (_, _, branch) in branches.iter() {
self.instructions.push(End); self.end_block();
self.build_stmt(branch, ret_layout)?; self.build_stmt(branch, ret_layout)?;
} }
Ok(()) Ok(())
} }
Stmt::Join {
id,
parameters,
body,
remainder,
} => {
// make locals for the join point parameters
let mut jp_parameter_local_ids = std::vec::Vec::with_capacity(parameters.len());
for parameter in parameters.iter() {
let wasm_layout = WasmLayout::new(&parameter.layout)?;
let local_id = self.insert_local(wasm_layout, parameter.symbol);
jp_parameter_local_ids.push(local_id);
}
self.start_block();
self.joinpoint_label_map
.insert(*id, (self.block_depth, jp_parameter_local_ids));
self.build_stmt(remainder, ret_layout)?;
self.end_block();
// A `return` inside of a `loop` seems to make it so that the `loop` itself
// also "returns" (so, leaves on the stack) a value of the return type.
let return_wasm_layout = WasmLayout::new(ret_layout)?;
self.start_loop_with_return(return_wasm_layout.value_type);
self.build_stmt(body, ret_layout)?;
// ends the loop
self.end_block();
Ok(())
}
Stmt::Jump(id, arguments) => {
let (target, locals) = &self.joinpoint_label_map[id];
// put the arguments on the stack
for (symbol, local_id) in arguments.iter().zip(locals.iter()) {
let argument = match self.symbol_storage_map.get(symbol) {
Some(SymbolStorage(local_id, _)) => local_id.0,
None => unreachable!("symbol not defined: {:?}", symbol),
};
self.instructions.push(GetLocal(argument));
self.instructions.push(SetLocal(local_id.0));
}
// jump
let levels = self.block_depth - target;
self.instructions.push(Br(levels));
Ok(())
}
x => Err(format!("statement not yet implemented: {:?}", x)), x => Err(format!("statement not yet implemented: {:?}", x)),
} }
} }
@ -280,7 +357,7 @@ impl<'a> WasmBackend<'a> {
layout: &Layout<'a>, layout: &Layout<'a>,
) -> Result<(), String> { ) -> Result<(), String> {
match expr { match expr {
Expr::Literal(lit) => self.load_literal(lit), Expr::Literal(lit) => self.load_literal(lit, layout),
Expr::Call(roc_mono::ir::Call { Expr::Call(roc_mono::ir::Call {
call_type, call_type,
@ -308,7 +385,7 @@ impl<'a> WasmBackend<'a> {
} }
} }
fn load_literal(&mut self, lit: &Literal<'a>) -> Result<(), String> { fn load_literal(&mut self, lit: &Literal<'a>, layout: &Layout<'a>) -> Result<(), String> {
match lit { match lit {
Literal::Bool(x) => { Literal::Bool(x) => {
self.instructions.push(I32Const(*x as i32)); self.instructions.push(I32Const(*x as i32));
@ -319,7 +396,15 @@ impl<'a> WasmBackend<'a> {
Ok(()) Ok(())
} }
Literal::Int(x) => { Literal::Int(x) => {
self.instructions.push(I64Const(*x as i64)); match layout {
Layout::Builtin(Builtin::Int32) => {
self.instructions.push(I32Const(*x as i32));
}
Layout::Builtin(Builtin::Int64) => {
self.instructions.push(I64Const(*x as i64));
}
x => panic!("loading literal, {:?}, is not yet implemented", x),
}
Ok(()) Ok(())
} }
Literal::Float(x) => { Literal::Float(x) => {
@ -363,6 +448,22 @@ impl<'a> WasmBackend<'a> {
ValueType::F32 => &[F32Add], ValueType::F32 => &[F32Add],
ValueType::F64 => &[F64Add], ValueType::F64 => &[F64Add],
}, },
LowLevel::NumSub => match return_value_type {
ValueType::I32 => &[I32Sub],
ValueType::I64 => &[I64Sub],
ValueType::F32 => &[F32Sub],
ValueType::F64 => &[F64Sub],
},
LowLevel::NumMul => match return_value_type {
ValueType::I32 => &[I32Mul],
ValueType::I64 => &[I64Mul],
ValueType::F32 => &[F32Mul],
ValueType::F64 => &[F64Mul],
},
LowLevel::NumGt => {
// needs layout of the argument to be implemented fully
&[I32GtS]
}
_ => { _ => {
return Err(format!("unsupported low-level op {:?}", lowlevel)); return Err(format!("unsupported low-level op {:?}", lowlevel));
} }

View file

@ -25,6 +25,19 @@ pub fn build_module<'a>(
let mut backend = WasmBackend::new(); let mut backend = WasmBackend::new();
let mut layout_ids = LayoutIds::default(); let mut layout_ids = LayoutIds::default();
// Sort procedures by occurrence order.
//
// We sort by the "name", but those are interned strings, and the name that is
// interned first will have a lower number.
//
// However, the name that occurs first is always `main`, because it is in the (implicit)
// file header. Therefore, sorting high to low will put other functions before `main`.
//
// This means that, for now, other functions in the file have to be ordered "in reverse": if A
// uses B, then the name of A must first occur after the first occurrence of the name of B.
let mut procedures: std::vec::Vec<_> = procedures.into_iter().collect();
procedures.sort_by(|a, b| b.0 .0.cmp(&a.0 .0));
for ((sym, layout), proc) in procedures { for ((sym, layout), proc) in procedures {
let function_index = backend.build_proc(proc, sym)?; let function_index = backend.build_proc(proc, sym)?;
if env.exposed_to_host.contains(&sym) { if env.exposed_to_host.contains(&sym) {

View file

@ -116,6 +116,44 @@ mod dev_num {
); );
} }
// Binds the result of an `if` expression to `x` and then keeps using it:
// the `True` branch yields 111, and 111 + 123 = 234.
// Presumably this is lowered to a join point in the mono IR (hence the test name).
#[test]
fn join_point() {
assert_evals_to!(
indoc!(
r#"
x = if True then 111 else 222
x + 123
"#
),
234,
i64
);
}
// Accumulator-style factorial on `I32` whose recursive call is in tail position.
// `fac 8 1` must evaluate to 8! = 40_320.
#[test]
fn factorial() {
assert_evals_to!(
indoc!(
r#"
app "test" provides [ main ] to "./platform"
fac : I32, I32 -> I32
fac = \n, accum ->
if n > 1 then
fac (n - 1) (n * accum)
else
accum
main : I32
main = fac 8 1
"#
),
40_320,
i32
);
}
// #[test] // #[test]
// fn gen_add_f64() { // fn gen_add_f64() {
// assert_evals_to!( // assert_evals_to!(

View file

@ -3927,7 +3927,7 @@ fn make_specializations<'a>(
); );
let external_specializations_requested = procs.externals_we_need.clone(); let external_specializations_requested = procs.externals_we_need.clone();
let procedures = procs.get_specialized_procs_without_rc(mono_env.arena); let procedures = procs.get_specialized_procs_without_rc(&mut mono_env);
let make_specializations_end = SystemTime::now(); let make_specializations_end = SystemTime::now();
module_timing.make_specializations = make_specializations_end module_timing.make_specializations = make_specializations_end

View file

@ -272,6 +272,33 @@ impl<'a> Proc<'a> {
proc.body = b.clone(); proc.body = b.clone();
} }
} }
/// Rewrites this procedure's body via `crate::tail_recursion::make_tail_recursive`
/// if the procedure is marked `SelfRecursive` and that transformation returns `Some`.
///
/// A fresh symbol is generated for every argument. When the transformation succeeds,
/// the procedure's argument list is replaced by those fresh symbols, and the
/// `(layout, original symbol, fresh symbol)` triples are handed to the transformation.
/// If the procedure is not self-recursive, or the transformation returns `None`,
/// `body` and `args` are left untouched.
fn make_tail_recursive(&mut self, env: &mut Env<'a, '_>) {
// (layout, original symbol, fresh symbol) triples passed to the transformation
let mut args = Vec::with_capacity_in(self.args.len(), env.arena);
// (layout, fresh symbol) pairs: the replacement argument list for this proc
let mut proc_args = Vec::with_capacity_in(self.args.len(), env.arena);
// NOTE: fresh symbols are generated for every proc, even when it turns out
// not to be self-recursive and the symbols end up unused.
for (layout, symbol) in self.args {
let new = env.unique_symbol();
args.push((*layout, *symbol, new));
proc_args.push((*layout, new));
}
use self::SelfRecursive::*;
if let SelfRecursive(id) = self.is_self_recursive {
let transformed = crate::tail_recursion::make_tail_recursive(
env.arena,
id,
self.name,
self.body.clone(),
args.into_bump_slice(),
);
// Only commit the new body and argument list when the transformation produced one.
if let Some(with_tco) = transformed {
self.body = with_tco;
self.args = proc_args.into_bump_slice();
}
}
}
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -350,7 +377,7 @@ pub enum InProgressProc<'a> {
impl<'a> Procs<'a> { impl<'a> Procs<'a> {
pub fn get_specialized_procs_without_rc( pub fn get_specialized_procs_without_rc(
self, self,
arena: &'a Bump, env: &mut Env<'a, '_>,
) -> MutMap<(Symbol, ProcLayout<'a>), Proc<'a>> { ) -> MutMap<(Symbol, ProcLayout<'a>), Proc<'a>> {
let mut result = MutMap::with_capacity_and_hasher(self.specialized.len(), default_hasher()); let mut result = MutMap::with_capacity_and_hasher(self.specialized.len(), default_hasher());
@ -376,16 +403,7 @@ impl<'a> Procs<'a> {
panic!(); panic!();
} }
Done(mut proc) => { Done(mut proc) => {
use self::SelfRecursive::*; proc.make_tail_recursive(env);
if let SelfRecursive(id) = proc.is_self_recursive {
proc.body = crate::tail_recursion::make_tail_recursive(
arena,
id,
proc.name,
proc.body.clone(),
proc.args,
);
}
result.insert(key, proc); result.insert(key, proc);
} }
@ -395,86 +413,6 @@ impl<'a> Procs<'a> {
result result
} }
// TODO investigate make this an iterator?
pub fn get_specialized_procs(
self,
arena: &'a Bump,
) -> MutMap<(Symbol, ProcLayout<'a>), Proc<'a>> {
let mut result = MutMap::with_capacity_and_hasher(self.specialized.len(), default_hasher());
for ((s, toplevel), in_prog_proc) in self.specialized.into_iter() {
match in_prog_proc {
InProgress => unreachable!(
"The procedure {:?} should have be done by now",
(s, toplevel)
),
Done(proc) => {
result.insert((s, toplevel), proc);
}
}
}
for (_, proc) in result.iter_mut() {
use self::SelfRecursive::*;
if let SelfRecursive(id) = proc.is_self_recursive {
proc.body = crate::tail_recursion::make_tail_recursive(
arena,
id,
proc.name,
proc.body.clone(),
proc.args,
);
}
}
let borrow_params = arena.alloc(crate::borrow::infer_borrow(arena, &result));
crate::inc_dec::visit_procs(arena, borrow_params, &mut result);
result
}
pub fn get_specialized_procs_help(
self,
arena: &'a Bump,
) -> (
MutMap<(Symbol, ProcLayout<'a>), Proc<'a>>,
&'a crate::borrow::ParamMap<'a>,
) {
let mut result = MutMap::with_capacity_and_hasher(self.specialized.len(), default_hasher());
for ((s, toplevel), in_prog_proc) in self.specialized.into_iter() {
match in_prog_proc {
InProgress => unreachable!(
"The procedure {:?} should have be done by now",
(s, toplevel)
),
Done(proc) => {
result.insert((s, toplevel), proc);
}
}
}
for (_, proc) in result.iter_mut() {
use self::SelfRecursive::*;
if let SelfRecursive(id) = proc.is_self_recursive {
proc.body = crate::tail_recursion::make_tail_recursive(
arena,
id,
proc.name,
proc.body.clone(),
proc.args,
);
}
}
let borrow_params = arena.alloc(crate::borrow::infer_borrow(arena, &result));
crate::inc_dec::visit_procs(arena, borrow_params, &mut result);
(result, borrow_params)
}
// TODO trim down these arguments! // TODO trim down these arguments!
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn insert_named( pub fn insert_named(

View file

@ -33,16 +33,16 @@ pub fn make_tail_recursive<'a>(
id: JoinPointId, id: JoinPointId,
needle: Symbol, needle: Symbol,
stmt: Stmt<'a>, stmt: Stmt<'a>,
args: &'a [(Layout<'a>, Symbol)], args: &'a [(Layout<'a>, Symbol, Symbol)],
) -> Stmt<'a> { ) -> Option<Stmt<'a>> {
let allocated = arena.alloc(stmt); let allocated = arena.alloc(stmt);
match insert_jumps(arena, allocated, id, needle) { match insert_jumps(arena, allocated, id, needle) {
None => allocated.clone(), None => None,
Some(new) => { Some(new) => {
// jumps were inserted, we must now add a join point // jumps were inserted, we must now add a join point
let params = Vec::from_iter_in( let params = Vec::from_iter_in(
args.iter().map(|(layout, symbol)| Param { args.iter().map(|(layout, symbol, _)| Param {
symbol: *symbol, symbol: *symbol,
layout: *layout, layout: *layout,
borrow: true, borrow: true,
@ -52,16 +52,18 @@ pub fn make_tail_recursive<'a>(
.into_bump_slice(); .into_bump_slice();
// TODO could this be &[]? // TODO could this be &[]?
let args = Vec::from_iter_in(args.iter().map(|t| t.1), arena).into_bump_slice(); let args = Vec::from_iter_in(args.iter().map(|t| t.2), arena).into_bump_slice();
let jump = arena.alloc(Stmt::Jump(id, args)); let jump = arena.alloc(Stmt::Jump(id, args));
Stmt::Join { let join = Stmt::Join {
id, id,
remainder: jump, remainder: jump,
parameters: params, parameters: params,
body: new, body: new,
} };
Some(join)
} }
} }
} }

View file

@ -6,7 +6,7 @@ procedure Num.26 (#Attr.2, #Attr.3):
let Test.12 = lowlevel NumMul #Attr.2 #Attr.3; let Test.12 = lowlevel NumMul #Attr.2 #Attr.3;
ret Test.12; ret Test.12;
procedure Test.1 (Test.2, Test.3): procedure Test.1 (Test.17, Test.18):
joinpoint Test.7 Test.2 Test.3: joinpoint Test.7 Test.2 Test.3:
let Test.15 = 0i64; let Test.15 = 0i64;
let Test.16 = lowlevel Eq Test.15 Test.2; let Test.16 = lowlevel Eq Test.15 Test.2;
@ -18,7 +18,7 @@ procedure Test.1 (Test.2, Test.3):
let Test.11 = CallByName Num.26 Test.2 Test.3; let Test.11 = CallByName Num.26 Test.2 Test.3;
jump Test.7 Test.10 Test.11; jump Test.7 Test.10 Test.11;
in in
jump Test.7 Test.2 Test.3; jump Test.7 Test.17 Test.18;
procedure Test.0 (): procedure Test.0 ():
let Test.5 = 10i64; let Test.5 = 10i64;

View file

@ -1,4 +1,4 @@
procedure Test.3 (Test.4): procedure Test.3 (Test.29):
joinpoint Test.13 Test.4: joinpoint Test.13 Test.4:
let Test.23 = 1i64; let Test.23 = 1i64;
let Test.24 = GetTagId Test.4; let Test.24 = GetTagId Test.4;
@ -18,7 +18,7 @@ procedure Test.3 (Test.4):
let Test.7 = UnionAtIndex (Id 0) (Index 1) Test.4; let Test.7 = UnionAtIndex (Id 0) (Index 1) Test.4;
jump Test.13 Test.7; jump Test.13 Test.7;
in in
jump Test.13 Test.4; jump Test.13 Test.29;
procedure Test.0 (): procedure Test.0 ():
let Test.28 = 3i64; let Test.28 = 3i64;

View file

@ -10,7 +10,7 @@ procedure Num.27 (#Attr.2, #Attr.3):
let Test.26 = lowlevel NumLt #Attr.2 #Attr.3; let Test.26 = lowlevel NumLt #Attr.2 #Attr.3;
ret Test.26; ret Test.26;
procedure Test.1 (Test.2, Test.3, Test.4): procedure Test.1 (Test.29, Test.30, Test.31):
joinpoint Test.12 Test.2 Test.3 Test.4: joinpoint Test.12 Test.2 Test.3 Test.4:
let Test.14 = CallByName Num.27 Test.3 Test.4; let Test.14 = CallByName Num.27 Test.3 Test.4;
if Test.14 then if Test.14 then
@ -29,7 +29,7 @@ procedure Test.1 (Test.2, Test.3, Test.4):
else else
ret Test.2; ret Test.2;
in in
jump Test.12 Test.2 Test.3 Test.4; jump Test.12 Test.29 Test.30 Test.31;
procedure Test.0 (): procedure Test.0 ():
let Test.9 = Array []; let Test.9 = Array [];