This commit is contained in:
Folkert 2023-06-18 14:21:48 +02:00
parent 4a9514d2c4
commit 0247237fe8
No known key found for this signature in database
GPG key ID: 1F17F6FFD112B97C
16 changed files with 625 additions and 163 deletions

View file

@ -1,11 +1,88 @@
#![allow(clippy::manual_map)]
use crate::borrow::Ownership;
use crate::ir::{Call, CallType, Env, Expr, JoinPointId, Param, Proc, SelfRecursive, Stmt};
use crate::layout::{InLayout, LambdaName, LayoutInterner, LayoutRepr, TagIdIntType, UnionLayout};
use crate::ir::{
Call, CallType, Expr, JoinPointId, Param, Proc, ProcLayout, SelfRecursive, Stmt, UpdateModeId,
};
use crate::layout::{
InLayout, LambdaName, Layout, LayoutInterner, LayoutRepr, STLayoutInterner, TagIdIntType,
UnionLayout,
};
use bumpalo::collections::Vec;
use bumpalo::Bump;
use roc_module::symbol::Symbol;
use roc_collections::MutMap;
use roc_module::low_level::LowLevel;
use roc_module::symbol::{IdentIds, ModuleId, Symbol};
pub struct Env<'a, 'i> {
arena: &'a Bump,
home: ModuleId,
interner: &'i mut STLayoutInterner<'a>,
ident_ids: &'i mut IdentIds,
}
impl<'a, 'i> Env<'a, 'i> {
pub fn unique_symbol(&mut self) -> Symbol {
let ident_id = self.ident_ids.gen_unique();
Symbol::new(self.home, ident_id)
}
pub fn named_unique_symbol(&mut self, name: &str) -> Symbol {
let ident_id = self.ident_ids.add_str(name);
Symbol::new(self.home, ident_id)
}
}
pub fn apply_trmc<'a, 'i>(
arena: &'a Bump,
interner: &'i mut STLayoutInterner<'a>,
home: ModuleId,
ident_ids: &'i mut IdentIds,
procs: &mut MutMap<(Symbol, ProcLayout<'a>), Proc<'a>>,
) {
let mut env = Env {
arena,
interner,
home,
ident_ids,
};
let env = &mut env;
for (_, proc) in procs {
use self::SelfRecursive::*;
if let SelfRecursive(id) = proc.is_self_recursive {
if crate::tail_recursion::is_trmc_candidate(env.interner, proc) {
let new_proc = crate::tail_recursion::TrmcEnv::init(env, proc);
*proc = new_proc;
} else {
let mut args = Vec::with_capacity_in(proc.args.len(), arena);
let mut proc_args = Vec::with_capacity_in(proc.args.len(), arena);
for (layout, symbol) in proc.args {
let new = env.unique_symbol();
args.push((*layout, *symbol, new));
proc_args.push((*layout, new));
}
let transformed = crate::tail_recursion::make_tail_recursive(
arena,
id,
proc.name,
proc.body.clone(),
args.into_bump_slice(),
proc.ret_layout,
);
if let Some(with_tco) = transformed {
proc.body = with_tco;
proc.args = proc_args.into_bump_slice();
}
}
}
}
}
/// Make tail calls into loops (using join points)
///
@ -325,7 +402,7 @@ fn insert_jumps<'a>(
}
}
pub(crate) fn is_trmc_candidate<'a, I>(interner: &I, proc: &Proc<'a>) -> bool
pub(crate) fn is_trmc_candidate<'a, I>(interner: &'_ I, proc: &'_ Proc<'a>) -> bool
where
I: LayoutInterner<'a>,
{
@ -338,10 +415,68 @@ where
}
// and return a recursive tag union
match interner.get_repr(proc.ret_layout) {
LayoutRepr::Union(union_layout) => union_layout.is_recursive(),
_ => false,
if !matches!(interner.get_repr(proc.ret_layout), LayoutRepr::Union(union_layout) if union_layout.is_recursive())
{
return false;
}
has_cons_in_tail_position(&proc.body, proc.name)
}
fn has_cons_in_tail_position(initial_stmt: &Stmt<'_>, function_name: LambdaName) -> bool {
// we are looking for code of the form
//
// let x = Tag a b c
// ret x
let mut stack = vec![(None, initial_stmt)];
while let Some((recursive_call, stmt)) = stack.pop() {
match stmt {
Stmt::Let(symbol, expr, _, next) => {
if let Some(cons_info) = TrmcEnv::is_terminal_constructor(stmt) {
// must use the result of a recursive call directly as an argument
if let Some(recursive_call) = recursive_call {
if cons_info.arguments.contains(&recursive_call) {
return true;
}
}
}
let recursive_call = recursive_call
.or_else(|| TrmcEnv::is_recursive_expr(expr, function_name).map(|_| *symbol));
stack.push((recursive_call, next));
}
Stmt::Switch {
branches,
default_branch,
..
} => {
for (_, _, stmt) in branches.iter() {
stack.push((recursive_call, stmt));
}
stack.push((recursive_call, default_branch.1));
}
Stmt::Refcounting(_, next) => {
stack.push((recursive_call, next));
}
Stmt::Expect { remainder, .. }
| Stmt::ExpectFx { remainder, .. }
| Stmt::Dbg { remainder, .. } => {
stack.push((recursive_call, remainder));
}
Stmt::Join {
body, remainder, ..
} => {
stack.push((recursive_call, body));
stack.push((recursive_call, remainder));
}
Stmt::Ret(_) | Stmt::Jump(_, _) | Stmt::Crash(_, _) => { /* terminal */ }
}
}
false
}
#[derive(Clone)]
@ -358,6 +493,7 @@ pub(crate) struct TrmcEnv<'a> {
recursive_call: Option<(Symbol, Call<'a>)>,
}
#[derive(Debug)]
struct ConstructorInfo<'a> {
tag_layout: UnionLayout<'a>,
tag_id: TagIdIntType,
@ -365,25 +501,18 @@ struct ConstructorInfo<'a> {
}
impl<'a> TrmcEnv<'a> {
fn is_recursive_expr(&mut self, expr: &Expr<'a>) -> Option<Call<'a>> {
if let Expr::Call(call) = expr {
self.is_recursive_call(call).then_some(call.clone())
} else {
None
}
}
fn is_terminal_constructor(&mut self, stmt: &Stmt<'a>) -> Option<ConstructorInfo<'a>> {
#[inline(always)]
fn is_terminal_constructor(stmt: &Stmt<'a>) -> Option<ConstructorInfo<'a>> {
match stmt {
Stmt::Let(s1, expr, _layout, Stmt::Ret(s2)) if s1 == s2 => {
self.get_contructor_info(expr)
Self::get_contructor_info(expr)
}
_ => None,
}
}
fn get_contructor_info(&mut self, expr: &Expr<'a>) -> Option<ConstructorInfo<'a>> {
fn get_contructor_info(expr: &Expr<'a>) -> Option<ConstructorInfo<'a>> {
if let Expr::Tag {
tag_layout,
tag_id,
@ -402,16 +531,19 @@ impl<'a> TrmcEnv<'a> {
}
}
fn is_recursive_call(&mut self, call: &Call<'a>) -> bool {
fn is_recursive_expr(expr: &Expr<'a>, lambda_name: LambdaName<'_>) -> Option<Call<'a>> {
if let Expr::Call(call) = expr {
Self::is_recursive_call(call, lambda_name).then_some(call.clone())
} else {
None
}
}
fn is_recursive_call(call: &Call<'a>, lambda_name: LambdaName<'_>) -> bool {
match call.call_type {
CallType::ByName {
name,
ret_layout,
arg_layouts,
specialization_id,
} => {
CallType::ByName { name, .. } => {
// TODO are there other restrictions?
name == self.function_name
name == lambda_name
}
CallType::Foreign { .. } | CallType::LowLevel { .. } | CallType::HigherOrder(_) => {
false
@ -421,16 +553,16 @@ impl<'a> TrmcEnv<'a> {
fn ptr_write(
env: &mut Env<'a, '_>,
interner: &mut impl LayoutInterner<'a>,
return_layout: InLayout<'a>,
_return_layout: InLayout<'a>,
ptr: Symbol,
value: Symbol,
next: &'a Stmt<'a>,
) -> Stmt<'a> {
let box_write = Call {
call_type: crate::ir::CallType::LowLevel {
op: roc_module::low_level::LowLevel::PtrWrite,
update_mode: env.next_update_mode_id(),
op: LowLevel::PtrStore,
// update_mode: env.next_update_mode_id(),
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: env.arena.alloc([ptr, value]),
};
@ -438,16 +570,13 @@ impl<'a> TrmcEnv<'a> {
Stmt::Let(
env.named_unique_symbol("_ptr_write_unit"),
Expr::Call(box_write),
interner.insert_direct_no_semantic(LayoutRepr::Boxed(return_layout)),
// interner.insert_direct_no_semantic(LayoutRepr::Boxed(return_layout)),
Layout::UNIT,
next,
)
}
pub fn init(
env: &mut Env<'a, '_>,
interner: &mut impl LayoutInterner<'a>,
proc: &Proc<'a>,
) -> Proc<'a> {
pub fn init<'i>(env: &mut Env<'a, 'i>, proc: &Proc<'a>) -> Proc<'a> {
let arena = env.arena;
let return_layout = proc.ret_layout;
@ -475,8 +604,9 @@ impl<'a> TrmcEnv<'a> {
let null_symbol = env.named_unique_symbol("null");
let let_null = |next| Stmt::Let(null_symbol, Expr::NullPointer, return_layout, next);
let box_return_layout =
interner.insert_direct_no_semantic(LayoutRepr::Boxed(return_layout));
let box_return_layout = env
.interner
.insert_direct_no_semantic(LayoutRepr::Boxed(return_layout));
let box_null = Expr::ExprBox {
symbol: null_symbol,
};
@ -508,7 +638,7 @@ impl<'a> TrmcEnv<'a> {
let joinpoint = Stmt::Join {
id: joinpoint_id,
parameters: joinpoint_parameters.into_bump_slice(),
body: arena.alloc(this.walk_stmt(env, interner, &proc.body)),
body: arena.alloc(this.walk_stmt(env, &proc.body)),
remainder: arena.alloc(jump_stmt),
};
@ -534,24 +664,19 @@ impl<'a> TrmcEnv<'a> {
}
}
fn walk_stmt(
&mut self,
env: &mut Env<'a, '_>,
interner: &mut impl LayoutInterner<'a>,
stmt: &Stmt<'a>,
) -> Stmt<'a> {
fn walk_stmt(&mut self, env: &mut Env<'a, '_>, stmt: &Stmt<'a>) -> Stmt<'a> {
let arena = env.arena;
match stmt {
Stmt::Let(symbol, expr, layout, next) => {
if self.recursive_call.is_none() {
if let Some(call) = self.is_recursive_expr(expr) {
if let Some(call) = Self::is_recursive_expr(expr, self.function_name) {
self.recursive_call = Some((*symbol, call));
return self.walk_stmt(env, interner, next);
return self.walk_stmt(env, next);
}
}
if let Some(cons_info) = self.is_terminal_constructor(stmt) {
if let Some(cons_info) = Self::is_terminal_constructor(stmt) {
match &self.recursive_call {
None => {
// this control flow path did not encounter a recursive call. Just
@ -561,7 +686,7 @@ impl<'a> TrmcEnv<'a> {
let output = define_tag(arena.alloc(
//
self.non_trmc_return(env, interner, *symbol),
self.non_trmc_return(env, *symbol),
));
return output;
@ -571,11 +696,21 @@ impl<'a> TrmcEnv<'a> {
// branch.
// TODO remove unwrap. also what if the symbol occurs more than once?
let recursive_field_index = cons_info
.arguments
.iter()
.position(|s| *s == *call_symbol)
.unwrap();
let opt_recursive_field_index =
cons_info.arguments.iter().position(|s| *s == *call_symbol);
let recursive_field_index = match opt_recursive_field_index {
None => {
let next = self.walk_stmt(env, next);
return Stmt::Let(
*symbol,
expr.clone(),
*layout,
arena.alloc(next),
);
}
Some(v) => v,
};
let mut arguments =
Vec::from_iter_in(cons_info.arguments.iter().copied(), env.arena);
@ -589,8 +724,11 @@ impl<'a> TrmcEnv<'a> {
let let_tag = |next| Stmt::Let(*symbol, tag_expr, *layout, next);
let get_reference_expr = Expr::ExprBox {
symbol: self.null_symbol,
let get_reference_expr = Expr::UnionFieldPtrAtIndex {
structure: *symbol,
tag_id: cons_info.tag_id,
union_layout: cons_info.tag_layout,
index: recursive_field_index as _,
};
let new_hole_symbol = env.named_unique_symbol("newHole");
@ -616,7 +754,6 @@ impl<'a> TrmcEnv<'a> {
//
Self::ptr_write(
env,
interner,
*layout,
self.hole_symbol,
*symbol,
@ -630,7 +767,7 @@ impl<'a> TrmcEnv<'a> {
}
}
let next = self.walk_stmt(env, interner, next);
let next = self.walk_stmt(env, next);
Stmt::Let(*symbol, expr.clone(), *layout, arena.alloc(next))
}
Stmt::Switch {
@ -646,14 +783,13 @@ impl<'a> TrmcEnv<'a> {
for (id, info, stmt) in branches.iter() {
self.recursive_call = opt_recursive_call.clone();
let new_stmt = self.walk_stmt(env, interner, stmt);
let new_stmt = self.walk_stmt(env, stmt);
new_branches.push((*id, info.clone(), new_stmt));
}
self.recursive_call = opt_recursive_call;
let new_default_branch =
&*arena.alloc(self.walk_stmt(env, interner, default_branch.1));
let new_default_branch = &*arena.alloc(self.walk_stmt(env, default_branch.1));
Stmt::Switch {
cond_symbol: *cond_symbol,
@ -666,42 +802,92 @@ impl<'a> TrmcEnv<'a> {
Stmt::Ret(symbol) => {
// write the symbol we're supposed to return into the hole
// then read initial_symbol and return its contents
self.non_trmc_return(env, interner, *symbol)
self.non_trmc_return(env, *symbol)
}
Stmt::Refcounting(_, _) => todo!(),
Stmt::Expect { .. } => todo!(),
Stmt::ExpectFx { .. } => todo!(),
Stmt::Dbg { .. } => todo!(),
Stmt::Join { .. } => todo!(),
Stmt::Jump(_, _) => todo!(),
Stmt::Crash(_, _) => todo!(),
Stmt::Refcounting(op, next) => {
let new_next = self.walk_stmt(env, next);
Stmt::Refcounting(*op, arena.alloc(new_next))
}
Stmt::Expect {
condition,
region,
lookups,
variables,
remainder,
} => Stmt::Expect {
condition: *condition,
region: *region,
lookups,
variables,
remainder: arena.alloc(self.walk_stmt(env, remainder)),
},
Stmt::ExpectFx {
condition,
region,
lookups,
variables,
remainder,
} => Stmt::Expect {
condition: *condition,
region: *region,
lookups,
variables,
remainder: arena.alloc(self.walk_stmt(env, remainder)),
},
Stmt::Dbg {
symbol,
variable,
remainder,
} => Stmt::Dbg {
symbol: *symbol,
variable: *variable,
remainder: arena.alloc(self.walk_stmt(env, remainder)),
},
Stmt::Join {
id,
parameters,
body,
remainder,
} => {
let new_body = self.walk_stmt(env, body);
let new_remainder = self.walk_stmt(env, remainder);
Stmt::Join {
id: *id,
parameters,
body: arena.alloc(new_body),
remainder: arena.alloc(new_remainder),
}
}
Stmt::Jump(id, arguments) => Stmt::Jump(*id, arguments),
Stmt::Crash(symbol, crash_tag) => Stmt::Crash(*symbol, *crash_tag),
}
}
fn non_trmc_return(
&mut self,
env: &mut Env<'a, '_>,
interner: &mut impl LayoutInterner<'a>,
value_symbol: Symbol,
) -> Stmt<'a> {
fn non_trmc_return(&mut self, env: &mut Env<'a, '_>, value_symbol: Symbol) -> Stmt<'a> {
let arena = env.arena;
let layout = self.return_layout;
let unbox_expr = Expr::ExprUnbox {
symbol: self.initial_box_symbol,
};
let final_symbol = env.named_unique_symbol("final");
let unbox = |next| Stmt::Let(final_symbol, unbox_expr, layout, next);
let call = Call {
call_type: CallType::LowLevel {
op: LowLevel::PtrLoad,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: &*arena.alloc([self.initial_box_symbol]),
};
let ptr_load = |next| Stmt::Let(final_symbol, Expr::Call(call), layout, next);
Self::ptr_write(
env,
interner,
layout,
self.hole_symbol,
value_symbol,
arena.alloc(
//
unbox(arena.alloc(Stmt::Ret(final_symbol))),
ptr_load(arena.alloc(Stmt::Ret(final_symbol))),
),
)
}