rough implementation

This commit is contained in:
Folkert 2023-06-17 20:46:59 +02:00
parent 76dcb75ff6
commit 4a9514d2c4
No known key found for this signature in database
GPG key ID: 1F17F6FFD112B97C
7 changed files with 456 additions and 23 deletions

View file

@ -1,8 +1,8 @@
#![allow(clippy::manual_map)]
use crate::borrow::Ownership;
use crate::ir::{CallType, Expr, JoinPointId, Param, Stmt};
use crate::layout::{InLayout, LambdaName};
use crate::ir::{Call, CallType, Env, Expr, JoinPointId, Param, Proc, SelfRecursive, Stmt};
use crate::layout::{InLayout, LambdaName, LayoutInterner, LayoutRepr, TagIdIntType, UnionLayout};
use bumpalo::collections::Vec;
use bumpalo::Bump;
use roc_module::symbol::Symbol;
@ -29,6 +29,7 @@ use roc_module::symbol::Symbol;
///
/// This will effectively compile into a loop in llvm, and
/// won't grow the call stack for each iteration
pub fn make_tail_recursive<'a>(
arena: &'a Bump,
id: JoinPointId,
@ -323,3 +324,385 @@ fn insert_jumps<'a>(
Crash(..) => None,
}
}
pub(crate) fn is_trmc_candidate<'a, I>(interner: &I, proc: &Proc<'a>) -> bool
where
I: LayoutInterner<'a>,
{
// it must be a self-recursive function
if !matches!(
proc.is_self_recursive,
crate::ir::SelfRecursive::SelfRecursive(_)
) {
return false;
}
// and return a recursive tag union
match interner.get_repr(proc.ret_layout) {
LayoutRepr::Union(union_layout) => union_layout.is_recursive(),
_ => false,
}
}
#[derive(Clone)]
pub(crate) struct TrmcEnv<'a> {
function_name: LambdaName<'a>,
hole_symbol: Symbol,
null_symbol: Symbol,
initial_box_symbol: Symbol,
joinpoint_id: JoinPointId,
return_layout: InLayout<'a>,
box_return_layout: InLayout<'a>,
// the call we are performing TRMC on
recursive_call: Option<(Symbol, Call<'a>)>,
}
struct ConstructorInfo<'a> {
tag_layout: UnionLayout<'a>,
tag_id: TagIdIntType,
arguments: &'a [Symbol],
}
impl<'a> TrmcEnv<'a> {
fn is_recursive_expr(&mut self, expr: &Expr<'a>) -> Option<Call<'a>> {
if let Expr::Call(call) = expr {
self.is_recursive_call(call).then_some(call.clone())
} else {
None
}
}
fn is_terminal_constructor(&mut self, stmt: &Stmt<'a>) -> Option<ConstructorInfo<'a>> {
match stmt {
Stmt::Let(s1, expr, _layout, Stmt::Ret(s2)) if s1 == s2 => {
self.get_contructor_info(expr)
}
_ => None,
}
}
fn get_contructor_info(&mut self, expr: &Expr<'a>) -> Option<ConstructorInfo<'a>> {
if let Expr::Tag {
tag_layout,
tag_id,
arguments,
} = expr
{
let info = ConstructorInfo {
tag_layout: *tag_layout,
tag_id: *tag_id,
arguments,
};
Some(info)
} else {
None
}
}
fn is_recursive_call(&mut self, call: &Call<'a>) -> bool {
match call.call_type {
CallType::ByName {
name,
ret_layout,
arg_layouts,
specialization_id,
} => {
// TODO are there other restrictions?
name == self.function_name
}
CallType::Foreign { .. } | CallType::LowLevel { .. } | CallType::HigherOrder(_) => {
false
}
}
}
fn ptr_write(
env: &mut Env<'a, '_>,
interner: &mut impl LayoutInterner<'a>,
return_layout: InLayout<'a>,
ptr: Symbol,
value: Symbol,
next: &'a Stmt<'a>,
) -> Stmt<'a> {
let box_write = Call {
call_type: crate::ir::CallType::LowLevel {
op: roc_module::low_level::LowLevel::PtrWrite,
update_mode: env.next_update_mode_id(),
},
arguments: env.arena.alloc([ptr, value]),
};
Stmt::Let(
env.named_unique_symbol("_ptr_write_unit"),
Expr::Call(box_write),
interner.insert_direct_no_semantic(LayoutRepr::Boxed(return_layout)),
next,
)
}
pub fn init(
env: &mut Env<'a, '_>,
interner: &mut impl LayoutInterner<'a>,
proc: &Proc<'a>,
) -> Proc<'a> {
let arena = env.arena;
let return_layout = proc.ret_layout;
let mut joinpoint_parameters = Vec::with_capacity_in(proc.args.len() + 1, env.arena);
let mut new_proc_arguments = Vec::with_capacity_in(proc.args.len(), env.arena);
let mut jump_arguments = Vec::with_capacity_in(proc.args.len() + 1, env.arena);
for (i, (layout, old_symbol)) in proc.args.iter().enumerate() {
let symbol = env.named_unique_symbol(&format!("arg_{i}"));
new_proc_arguments.push((*layout, symbol));
jump_arguments.push(symbol);
let param = Param {
symbol: *old_symbol,
ownership: Ownership::Owned,
layout: *layout,
};
joinpoint_parameters.push(param);
}
// the root of the recursive structure that we'll be building
let initial_box_symbol = env.named_unique_symbol("initial");
jump_arguments.push(initial_box_symbol);
let null_symbol = env.named_unique_symbol("null");
let let_null = |next| Stmt::Let(null_symbol, Expr::NullPointer, return_layout, next);
let box_return_layout =
interner.insert_direct_no_semantic(LayoutRepr::Boxed(return_layout));
let box_null = Expr::ExprBox {
symbol: null_symbol,
};
let let_box = |next| Stmt::Let(initial_box_symbol, box_null, box_return_layout, next);
let joinpoint_id = JoinPointId(env.named_unique_symbol("trmc"));
let hole_symbol = env.named_unique_symbol("hole");
let jump_stmt = Stmt::Jump(joinpoint_id, jump_arguments.into_bump_slice());
let mut this = Self {
function_name: proc.name,
hole_symbol,
null_symbol,
initial_box_symbol,
joinpoint_id,
return_layout,
box_return_layout,
recursive_call: None,
};
let param = Param {
symbol: hole_symbol,
ownership: Ownership::Owned,
layout: box_return_layout,
};
joinpoint_parameters.push(param);
let joinpoint = Stmt::Join {
id: joinpoint_id,
parameters: joinpoint_parameters.into_bump_slice(),
body: arena.alloc(this.walk_stmt(env, interner, &proc.body)),
remainder: arena.alloc(jump_stmt),
};
let body = let_null(arena.alloc(
//
let_box(arena.alloc(
//
joinpoint,
)),
));
#[cfg(debug_assertions)]
env.home.register_debug_idents(env.ident_ids);
Proc {
name: proc.name,
args: new_proc_arguments.into_bump_slice(),
body,
closure_data_layout: proc.closure_data_layout,
ret_layout: proc.ret_layout,
is_self_recursive: SelfRecursive::NotSelfRecursive,
host_exposed_layouts: proc.host_exposed_layouts.clone(),
}
}
fn walk_stmt(
&mut self,
env: &mut Env<'a, '_>,
interner: &mut impl LayoutInterner<'a>,
stmt: &Stmt<'a>,
) -> Stmt<'a> {
let arena = env.arena;
match stmt {
Stmt::Let(symbol, expr, layout, next) => {
if self.recursive_call.is_none() {
if let Some(call) = self.is_recursive_expr(expr) {
self.recursive_call = Some((*symbol, call));
return self.walk_stmt(env, interner, next);
}
}
if let Some(cons_info) = self.is_terminal_constructor(stmt) {
match &self.recursive_call {
None => {
// this control flow path did not encounter a recursive call. Just
// write the end result into the hole and we're done.
let define_tag = |next| Stmt::Let(*symbol, expr.clone(), *layout, next);
let output = define_tag(arena.alloc(
//
self.non_trmc_return(env, interner, *symbol),
));
return output;
}
Some((call_symbol, call)) => {
// we did encounter a recursive call, and can perform TRMC in this
// branch.
// TODO remove unwrap. also what if the symbol occurs more than once?
let recursive_field_index = cons_info
.arguments
.iter()
.position(|s| *s == *call_symbol)
.unwrap();
let mut arguments =
Vec::from_iter_in(cons_info.arguments.iter().copied(), env.arena);
arguments[recursive_field_index] = self.null_symbol;
let tag_expr = Expr::Tag {
tag_layout: cons_info.tag_layout,
tag_id: cons_info.tag_id,
arguments: arguments.into_bump_slice(),
};
let let_tag = |next| Stmt::Let(*symbol, tag_expr, *layout, next);
let get_reference_expr = Expr::ExprBox {
symbol: self.null_symbol,
};
let new_hole_symbol = env.named_unique_symbol("newHole");
let let_new_hole = |next| {
Stmt::Let(
new_hole_symbol,
get_reference_expr,
self.box_return_layout,
next,
)
};
let mut jump_arguments =
Vec::from_iter_in(call.arguments.iter().copied(), env.arena);
jump_arguments.push(new_hole_symbol);
let jump =
Stmt::Jump(self.joinpoint_id, jump_arguments.into_bump_slice());
let output = let_tag(arena.alloc(
//
let_new_hole(arena.alloc(
//
Self::ptr_write(
env,
interner,
*layout,
self.hole_symbol,
*symbol,
arena.alloc(jump),
),
)),
));
return output;
}
}
}
let next = self.walk_stmt(env, interner, next);
Stmt::Let(*symbol, expr.clone(), *layout, arena.alloc(next))
}
Stmt::Switch {
cond_symbol,
cond_layout,
branches,
default_branch,
ret_layout,
} => {
let mut new_branches = Vec::with_capacity_in(branches.len(), arena);
let opt_recursive_call = self.recursive_call.clone();
for (id, info, stmt) in branches.iter() {
self.recursive_call = opt_recursive_call.clone();
let new_stmt = self.walk_stmt(env, interner, stmt);
new_branches.push((*id, info.clone(), new_stmt));
}
self.recursive_call = opt_recursive_call;
let new_default_branch =
&*arena.alloc(self.walk_stmt(env, interner, default_branch.1));
Stmt::Switch {
cond_symbol: *cond_symbol,
cond_layout: *cond_layout,
branches: &*arena.alloc(new_branches.into_bump_slice()),
default_branch: (default_branch.0.clone(), new_default_branch),
ret_layout: *ret_layout,
}
}
Stmt::Ret(symbol) => {
// write the symbol we're supposed to return into the hole
// then read initial_symbol and return its contents
self.non_trmc_return(env, interner, *symbol)
}
Stmt::Refcounting(_, _) => todo!(),
Stmt::Expect { .. } => todo!(),
Stmt::ExpectFx { .. } => todo!(),
Stmt::Dbg { .. } => todo!(),
Stmt::Join { .. } => todo!(),
Stmt::Jump(_, _) => todo!(),
Stmt::Crash(_, _) => todo!(),
}
}
fn non_trmc_return(
&mut self,
env: &mut Env<'a, '_>,
interner: &mut impl LayoutInterner<'a>,
value_symbol: Symbol,
) -> Stmt<'a> {
let arena = env.arena;
let layout = self.return_layout;
let unbox_expr = Expr::ExprUnbox {
symbol: self.initial_box_symbol,
};
let final_symbol = env.named_unique_symbol("final");
let unbox = |next| Stmt::Let(final_symbol, unbox_expr, layout, next);
Self::ptr_write(
env,
interner,
layout,
self.hole_symbol,
value_symbol,
arena.alloc(
//
unbox(arena.alloc(Stmt::Ret(final_symbol))),
),
)
}
}