mirror of
https://github.com/roc-lang/roc.git
synced 2025-10-03 08:34:33 +00:00
Merge pull request #1682 from rtfeldman/gen_wasm_join
joinpoints for the wasm backend
This commit is contained in:
commit
575aaa1f0b
9 changed files with 209 additions and 117 deletions
|
@ -7,7 +7,7 @@ use parity_wasm::elements::{
|
|||
use roc_collections::all::MutMap;
|
||||
use roc_module::low_level::LowLevel;
|
||||
use roc_module::symbol::Symbol;
|
||||
use roc_mono::ir::{CallType, Expr, Literal, Proc, Stmt};
|
||||
use roc_mono::ir::{CallType, Expr, JoinPointId, Literal, Proc, Stmt};
|
||||
use roc_mono::layout::{Builtin, Layout};
|
||||
|
||||
// Don't allocate any constant data at address zero or near it. Would be valid, but bug-prone.
|
||||
|
@ -23,7 +23,7 @@ struct LabelId(u32);
|
|||
#[derive(Debug)]
|
||||
struct SymbolStorage(LocalId, WasmLayout);
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
struct WasmLayout {
|
||||
value_type: ValueType,
|
||||
stack_memory: u32,
|
||||
|
@ -69,7 +69,9 @@ pub struct WasmBackend<'a> {
|
|||
// Functions: internal state & IR mappings
|
||||
stack_memory: u32,
|
||||
symbol_storage_map: MutMap<Symbol, SymbolStorage>,
|
||||
// joinpoint_label_map: MutMap<JoinPointId, LabelId>,
|
||||
/// how many blocks deep are we (used for jumps)
|
||||
block_depth: u32,
|
||||
joinpoint_label_map: MutMap<JoinPointId, (u32, std::vec::Vec<LocalId>)>,
|
||||
}
|
||||
|
||||
impl<'a> WasmBackend<'a> {
|
||||
|
@ -92,7 +94,8 @@ impl<'a> WasmBackend<'a> {
|
|||
// Functions: internal state & IR mappings
|
||||
stack_memory: 0,
|
||||
symbol_storage_map: MutMap::default(),
|
||||
// joinpoint_label_map: MutMap::default(),
|
||||
block_depth: 0,
|
||||
joinpoint_label_map: MutMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -182,6 +185,27 @@ impl<'a> WasmBackend<'a> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// start a loop that leaves a value on the stack
|
||||
fn start_loop_with_return(&mut self, value_type: ValueType) {
|
||||
self.block_depth += 1;
|
||||
|
||||
// self.instructions.push(Loop(BlockType::NoResult));
|
||||
self.instructions.push(Loop(BlockType::Value(value_type)));
|
||||
}
|
||||
|
||||
fn start_block(&mut self) {
|
||||
self.block_depth += 1;
|
||||
|
||||
// Our blocks always end with a `return` or `br`,
|
||||
// so they never leave extra values on the stack
|
||||
self.instructions.push(Block(BlockType::NoResult));
|
||||
}
|
||||
|
||||
fn end_block(&mut self) {
|
||||
self.block_depth -= 1;
|
||||
self.instructions.push(End);
|
||||
}
|
||||
|
||||
fn build_stmt(&mut self, stmt: &Stmt<'a>, ret_layout: &Layout<'a>) -> Result<(), String> {
|
||||
match stmt {
|
||||
// This pattern is a simple optimisation to get rid of one local and two instructions per proc.
|
||||
|
@ -228,11 +252,8 @@ impl<'a> WasmBackend<'a> {
|
|||
// or `BrTable`
|
||||
|
||||
// create (number_of_branches - 1) new blocks.
|
||||
//
|
||||
// Every branch ends in a `return`,
|
||||
// so the block leaves no values on the stack
|
||||
for _ in 0..branches.len() {
|
||||
self.instructions.push(Block(BlockType::NoResult));
|
||||
self.start_block()
|
||||
}
|
||||
|
||||
// the LocalId of the symbol that we match on
|
||||
|
@ -262,13 +283,69 @@ impl<'a> WasmBackend<'a> {
|
|||
// (the first branch would have broken out of 1 block,
|
||||
// hence we must generate its code first)
|
||||
for (_, _, branch) in branches.iter() {
|
||||
self.instructions.push(End);
|
||||
self.end_block();
|
||||
|
||||
self.build_stmt(branch, ret_layout)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Stmt::Join {
|
||||
id,
|
||||
parameters,
|
||||
body,
|
||||
remainder,
|
||||
} => {
|
||||
// make locals for join pointer parameters
|
||||
let mut jp_parameter_local_ids = std::vec::Vec::with_capacity(parameters.len());
|
||||
for parameter in parameters.iter() {
|
||||
let wasm_layout = WasmLayout::new(¶meter.layout)?;
|
||||
let local_id = self.insert_local(wasm_layout, parameter.symbol);
|
||||
|
||||
jp_parameter_local_ids.push(local_id);
|
||||
}
|
||||
|
||||
self.start_block();
|
||||
|
||||
self.joinpoint_label_map
|
||||
.insert(*id, (self.block_depth, jp_parameter_local_ids));
|
||||
|
||||
self.build_stmt(remainder, ret_layout)?;
|
||||
|
||||
self.end_block();
|
||||
|
||||
// A `return` inside of a `loop` seems to make it so that the `loop` itself
|
||||
// also "returns" (so, leaves on the stack) a value of the return type.
|
||||
let return_wasm_layout = WasmLayout::new(ret_layout)?;
|
||||
self.start_loop_with_return(return_wasm_layout.value_type);
|
||||
|
||||
self.build_stmt(body, ret_layout)?;
|
||||
|
||||
// ends the loop
|
||||
self.end_block();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Stmt::Jump(id, arguments) => {
|
||||
let (target, locals) = &self.joinpoint_label_map[id];
|
||||
|
||||
// put the arguments on the stack
|
||||
for (symbol, local_id) in arguments.iter().zip(locals.iter()) {
|
||||
let argument = match self.symbol_storage_map.get(symbol) {
|
||||
Some(SymbolStorage(local_id, _)) => local_id.0,
|
||||
None => unreachable!("symbol not defined: {:?}", symbol),
|
||||
};
|
||||
|
||||
self.instructions.push(GetLocal(argument));
|
||||
self.instructions.push(SetLocal(local_id.0));
|
||||
}
|
||||
|
||||
// jump
|
||||
let levels = self.block_depth - target;
|
||||
self.instructions.push(Br(levels));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
x => Err(format!("statement not yet implemented: {:?}", x)),
|
||||
}
|
||||
}
|
||||
|
@ -280,7 +357,7 @@ impl<'a> WasmBackend<'a> {
|
|||
layout: &Layout<'a>,
|
||||
) -> Result<(), String> {
|
||||
match expr {
|
||||
Expr::Literal(lit) => self.load_literal(lit),
|
||||
Expr::Literal(lit) => self.load_literal(lit, layout),
|
||||
|
||||
Expr::Call(roc_mono::ir::Call {
|
||||
call_type,
|
||||
|
@ -308,7 +385,7 @@ impl<'a> WasmBackend<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
fn load_literal(&mut self, lit: &Literal<'a>) -> Result<(), String> {
|
||||
fn load_literal(&mut self, lit: &Literal<'a>, layout: &Layout<'a>) -> Result<(), String> {
|
||||
match lit {
|
||||
Literal::Bool(x) => {
|
||||
self.instructions.push(I32Const(*x as i32));
|
||||
|
@ -319,7 +396,15 @@ impl<'a> WasmBackend<'a> {
|
|||
Ok(())
|
||||
}
|
||||
Literal::Int(x) => {
|
||||
match layout {
|
||||
Layout::Builtin(Builtin::Int32) => {
|
||||
self.instructions.push(I32Const(*x as i32));
|
||||
}
|
||||
Layout::Builtin(Builtin::Int64) => {
|
||||
self.instructions.push(I64Const(*x as i64));
|
||||
}
|
||||
x => panic!("loading literal, {:?}, is not yet implemented", x),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Literal::Float(x) => {
|
||||
|
@ -363,6 +448,22 @@ impl<'a> WasmBackend<'a> {
|
|||
ValueType::F32 => &[F32Add],
|
||||
ValueType::F64 => &[F64Add],
|
||||
},
|
||||
LowLevel::NumSub => match return_value_type {
|
||||
ValueType::I32 => &[I32Sub],
|
||||
ValueType::I64 => &[I64Sub],
|
||||
ValueType::F32 => &[F32Sub],
|
||||
ValueType::F64 => &[F64Sub],
|
||||
},
|
||||
LowLevel::NumMul => match return_value_type {
|
||||
ValueType::I32 => &[I32Mul],
|
||||
ValueType::I64 => &[I64Mul],
|
||||
ValueType::F32 => &[F32Mul],
|
||||
ValueType::F64 => &[F64Mul],
|
||||
},
|
||||
LowLevel::NumGt => {
|
||||
// needs layout of the argument to be implemented fully
|
||||
&[I32GtS]
|
||||
}
|
||||
_ => {
|
||||
return Err(format!("unsupported low-level op {:?}", lowlevel));
|
||||
}
|
||||
|
|
|
@ -25,6 +25,19 @@ pub fn build_module<'a>(
|
|||
let mut backend = WasmBackend::new();
|
||||
let mut layout_ids = LayoutIds::default();
|
||||
|
||||
// Sort procedures by occurrence order
|
||||
//
|
||||
// We sort by the "name", but those are interned strings, and the name that is
|
||||
// interned first will have a lower number.
|
||||
//
|
||||
// But, the name that occurs first is always `main` because it is in the (implicit)
|
||||
// file header. Therefore sorting high to low will put other functions before main
|
||||
//
|
||||
// This means that for now other functions in the file have to be ordered "in reverse": if A
|
||||
// uses B, then the name of A must first occur after the first occurrence of the name of B
|
||||
let mut procedures: std::vec::Vec<_> = procedures.into_iter().collect();
|
||||
procedures.sort_by(|a, b| b.0 .0.cmp(&a.0 .0));
|
||||
|
||||
for ((sym, layout), proc) in procedures {
|
||||
let function_index = backend.build_proc(proc, sym)?;
|
||||
if env.exposed_to_host.contains(&sym) {
|
||||
|
|
|
@ -116,6 +116,44 @@ mod dev_num {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn join_point() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
x = if True then 111 else 222
|
||||
|
||||
x + 123
|
||||
"#
|
||||
),
|
||||
234,
|
||||
i64
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factorial() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
app "test" provides [ main ] to "./platform"
|
||||
|
||||
fac : I32, I32 -> I32
|
||||
fac = \n, accum ->
|
||||
if n > 1 then
|
||||
fac (n - 1) (n * accum)
|
||||
else
|
||||
accum
|
||||
|
||||
main : I32
|
||||
main = fac 8 1
|
||||
"#
|
||||
),
|
||||
40_320,
|
||||
i32
|
||||
);
|
||||
}
|
||||
|
||||
// #[test]
|
||||
// fn gen_add_f64() {
|
||||
// assert_evals_to!(
|
||||
|
|
|
@ -3927,7 +3927,7 @@ fn make_specializations<'a>(
|
|||
);
|
||||
|
||||
let external_specializations_requested = procs.externals_we_need.clone();
|
||||
let procedures = procs.get_specialized_procs_without_rc(mono_env.arena);
|
||||
let procedures = procs.get_specialized_procs_without_rc(&mut mono_env);
|
||||
|
||||
let make_specializations_end = SystemTime::now();
|
||||
module_timing.make_specializations = make_specializations_end
|
||||
|
|
|
@ -272,6 +272,33 @@ impl<'a> Proc<'a> {
|
|||
proc.body = b.clone();
|
||||
}
|
||||
}
|
||||
|
||||
fn make_tail_recursive(&mut self, env: &mut Env<'a, '_>) {
|
||||
let mut args = Vec::with_capacity_in(self.args.len(), env.arena);
|
||||
let mut proc_args = Vec::with_capacity_in(self.args.len(), env.arena);
|
||||
|
||||
for (layout, symbol) in self.args {
|
||||
let new = env.unique_symbol();
|
||||
args.push((*layout, *symbol, new));
|
||||
proc_args.push((*layout, new));
|
||||
}
|
||||
|
||||
use self::SelfRecursive::*;
|
||||
if let SelfRecursive(id) = self.is_self_recursive {
|
||||
let transformed = crate::tail_recursion::make_tail_recursive(
|
||||
env.arena,
|
||||
id,
|
||||
self.name,
|
||||
self.body.clone(),
|
||||
args.into_bump_slice(),
|
||||
);
|
||||
|
||||
if let Some(with_tco) = transformed {
|
||||
self.body = with_tco;
|
||||
self.args = proc_args.into_bump_slice();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
|
@ -350,7 +377,7 @@ pub enum InProgressProc<'a> {
|
|||
impl<'a> Procs<'a> {
|
||||
pub fn get_specialized_procs_without_rc(
|
||||
self,
|
||||
arena: &'a Bump,
|
||||
env: &mut Env<'a, '_>,
|
||||
) -> MutMap<(Symbol, ProcLayout<'a>), Proc<'a>> {
|
||||
let mut result = MutMap::with_capacity_and_hasher(self.specialized.len(), default_hasher());
|
||||
|
||||
|
@ -376,16 +403,7 @@ impl<'a> Procs<'a> {
|
|||
panic!();
|
||||
}
|
||||
Done(mut proc) => {
|
||||
use self::SelfRecursive::*;
|
||||
if let SelfRecursive(id) = proc.is_self_recursive {
|
||||
proc.body = crate::tail_recursion::make_tail_recursive(
|
||||
arena,
|
||||
id,
|
||||
proc.name,
|
||||
proc.body.clone(),
|
||||
proc.args,
|
||||
);
|
||||
}
|
||||
proc.make_tail_recursive(env);
|
||||
|
||||
result.insert(key, proc);
|
||||
}
|
||||
|
@ -395,86 +413,6 @@ impl<'a> Procs<'a> {
|
|||
result
|
||||
}
|
||||
|
||||
// TODO investigate make this an iterator?
|
||||
pub fn get_specialized_procs(
|
||||
self,
|
||||
arena: &'a Bump,
|
||||
) -> MutMap<(Symbol, ProcLayout<'a>), Proc<'a>> {
|
||||
let mut result = MutMap::with_capacity_and_hasher(self.specialized.len(), default_hasher());
|
||||
|
||||
for ((s, toplevel), in_prog_proc) in self.specialized.into_iter() {
|
||||
match in_prog_proc {
|
||||
InProgress => unreachable!(
|
||||
"The procedure {:?} should have be done by now",
|
||||
(s, toplevel)
|
||||
),
|
||||
Done(proc) => {
|
||||
result.insert((s, toplevel), proc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (_, proc) in result.iter_mut() {
|
||||
use self::SelfRecursive::*;
|
||||
if let SelfRecursive(id) = proc.is_self_recursive {
|
||||
proc.body = crate::tail_recursion::make_tail_recursive(
|
||||
arena,
|
||||
id,
|
||||
proc.name,
|
||||
proc.body.clone(),
|
||||
proc.args,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let borrow_params = arena.alloc(crate::borrow::infer_borrow(arena, &result));
|
||||
|
||||
crate::inc_dec::visit_procs(arena, borrow_params, &mut result);
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
pub fn get_specialized_procs_help(
|
||||
self,
|
||||
arena: &'a Bump,
|
||||
) -> (
|
||||
MutMap<(Symbol, ProcLayout<'a>), Proc<'a>>,
|
||||
&'a crate::borrow::ParamMap<'a>,
|
||||
) {
|
||||
let mut result = MutMap::with_capacity_and_hasher(self.specialized.len(), default_hasher());
|
||||
|
||||
for ((s, toplevel), in_prog_proc) in self.specialized.into_iter() {
|
||||
match in_prog_proc {
|
||||
InProgress => unreachable!(
|
||||
"The procedure {:?} should have be done by now",
|
||||
(s, toplevel)
|
||||
),
|
||||
Done(proc) => {
|
||||
result.insert((s, toplevel), proc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (_, proc) in result.iter_mut() {
|
||||
use self::SelfRecursive::*;
|
||||
if let SelfRecursive(id) = proc.is_self_recursive {
|
||||
proc.body = crate::tail_recursion::make_tail_recursive(
|
||||
arena,
|
||||
id,
|
||||
proc.name,
|
||||
proc.body.clone(),
|
||||
proc.args,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let borrow_params = arena.alloc(crate::borrow::infer_borrow(arena, &result));
|
||||
|
||||
crate::inc_dec::visit_procs(arena, borrow_params, &mut result);
|
||||
|
||||
(result, borrow_params)
|
||||
}
|
||||
|
||||
// TODO trim down these arguments!
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn insert_named(
|
||||
|
|
|
@ -33,16 +33,16 @@ pub fn make_tail_recursive<'a>(
|
|||
id: JoinPointId,
|
||||
needle: Symbol,
|
||||
stmt: Stmt<'a>,
|
||||
args: &'a [(Layout<'a>, Symbol)],
|
||||
) -> Stmt<'a> {
|
||||
args: &'a [(Layout<'a>, Symbol, Symbol)],
|
||||
) -> Option<Stmt<'a>> {
|
||||
let allocated = arena.alloc(stmt);
|
||||
match insert_jumps(arena, allocated, id, needle) {
|
||||
None => allocated.clone(),
|
||||
None => None,
|
||||
Some(new) => {
|
||||
// jumps were inserted, we must now add a join point
|
||||
|
||||
let params = Vec::from_iter_in(
|
||||
args.iter().map(|(layout, symbol)| Param {
|
||||
args.iter().map(|(layout, symbol, _)| Param {
|
||||
symbol: *symbol,
|
||||
layout: *layout,
|
||||
borrow: true,
|
||||
|
@ -52,16 +52,18 @@ pub fn make_tail_recursive<'a>(
|
|||
.into_bump_slice();
|
||||
|
||||
// TODO could this be &[]?
|
||||
let args = Vec::from_iter_in(args.iter().map(|t| t.1), arena).into_bump_slice();
|
||||
let args = Vec::from_iter_in(args.iter().map(|t| t.2), arena).into_bump_slice();
|
||||
|
||||
let jump = arena.alloc(Stmt::Jump(id, args));
|
||||
|
||||
Stmt::Join {
|
||||
let join = Stmt::Join {
|
||||
id,
|
||||
remainder: jump,
|
||||
parameters: params,
|
||||
body: new,
|
||||
}
|
||||
};
|
||||
|
||||
Some(join)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ procedure Num.26 (#Attr.2, #Attr.3):
|
|||
let Test.12 = lowlevel NumMul #Attr.2 #Attr.3;
|
||||
ret Test.12;
|
||||
|
||||
procedure Test.1 (Test.2, Test.3):
|
||||
procedure Test.1 (Test.17, Test.18):
|
||||
joinpoint Test.7 Test.2 Test.3:
|
||||
let Test.15 = 0i64;
|
||||
let Test.16 = lowlevel Eq Test.15 Test.2;
|
||||
|
@ -18,7 +18,7 @@ procedure Test.1 (Test.2, Test.3):
|
|||
let Test.11 = CallByName Num.26 Test.2 Test.3;
|
||||
jump Test.7 Test.10 Test.11;
|
||||
in
|
||||
jump Test.7 Test.2 Test.3;
|
||||
jump Test.7 Test.17 Test.18;
|
||||
|
||||
procedure Test.0 ():
|
||||
let Test.5 = 10i64;
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
procedure Test.3 (Test.4):
|
||||
procedure Test.3 (Test.29):
|
||||
joinpoint Test.13 Test.4:
|
||||
let Test.23 = 1i64;
|
||||
let Test.24 = GetTagId Test.4;
|
||||
|
@ -18,7 +18,7 @@ procedure Test.3 (Test.4):
|
|||
let Test.7 = UnionAtIndex (Id 0) (Index 1) Test.4;
|
||||
jump Test.13 Test.7;
|
||||
in
|
||||
jump Test.13 Test.4;
|
||||
jump Test.13 Test.29;
|
||||
|
||||
procedure Test.0 ():
|
||||
let Test.28 = 3i64;
|
||||
|
|
|
@ -10,7 +10,7 @@ procedure Num.27 (#Attr.2, #Attr.3):
|
|||
let Test.26 = lowlevel NumLt #Attr.2 #Attr.3;
|
||||
ret Test.26;
|
||||
|
||||
procedure Test.1 (Test.2, Test.3, Test.4):
|
||||
procedure Test.1 (Test.29, Test.30, Test.31):
|
||||
joinpoint Test.12 Test.2 Test.3 Test.4:
|
||||
let Test.14 = CallByName Num.27 Test.3 Test.4;
|
||||
if Test.14 then
|
||||
|
@ -29,7 +29,7 @@ procedure Test.1 (Test.2, Test.3, Test.4):
|
|||
else
|
||||
ret Test.2;
|
||||
in
|
||||
jump Test.12 Test.2 Test.3 Test.4;
|
||||
jump Test.12 Test.29 Test.30 Test.31;
|
||||
|
||||
procedure Test.0 ():
|
||||
let Test.9 = Array [];
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue