mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-26 21:39:07 +00:00
WIP
This commit is contained in:
parent
4a9514d2c4
commit
0247237fe8
16 changed files with 625 additions and 163 deletions
|
@ -1,11 +1,88 @@
|
|||
#![allow(clippy::manual_map)]
|
||||
|
||||
use crate::borrow::Ownership;
|
||||
use crate::ir::{Call, CallType, Env, Expr, JoinPointId, Param, Proc, SelfRecursive, Stmt};
|
||||
use crate::layout::{InLayout, LambdaName, LayoutInterner, LayoutRepr, TagIdIntType, UnionLayout};
|
||||
use crate::ir::{
|
||||
Call, CallType, Expr, JoinPointId, Param, Proc, ProcLayout, SelfRecursive, Stmt, UpdateModeId,
|
||||
};
|
||||
use crate::layout::{
|
||||
InLayout, LambdaName, Layout, LayoutInterner, LayoutRepr, STLayoutInterner, TagIdIntType,
|
||||
UnionLayout,
|
||||
};
|
||||
use bumpalo::collections::Vec;
|
||||
use bumpalo::Bump;
|
||||
use roc_module::symbol::Symbol;
|
||||
use roc_collections::MutMap;
|
||||
use roc_module::low_level::LowLevel;
|
||||
use roc_module::symbol::{IdentIds, ModuleId, Symbol};
|
||||
|
||||
pub struct Env<'a, 'i> {
|
||||
arena: &'a Bump,
|
||||
home: ModuleId,
|
||||
interner: &'i mut STLayoutInterner<'a>,
|
||||
ident_ids: &'i mut IdentIds,
|
||||
}
|
||||
|
||||
impl<'a, 'i> Env<'a, 'i> {
|
||||
pub fn unique_symbol(&mut self) -> Symbol {
|
||||
let ident_id = self.ident_ids.gen_unique();
|
||||
|
||||
Symbol::new(self.home, ident_id)
|
||||
}
|
||||
|
||||
pub fn named_unique_symbol(&mut self, name: &str) -> Symbol {
|
||||
let ident_id = self.ident_ids.add_str(name);
|
||||
Symbol::new(self.home, ident_id)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn apply_trmc<'a, 'i>(
|
||||
arena: &'a Bump,
|
||||
interner: &'i mut STLayoutInterner<'a>,
|
||||
home: ModuleId,
|
||||
ident_ids: &'i mut IdentIds,
|
||||
procs: &mut MutMap<(Symbol, ProcLayout<'a>), Proc<'a>>,
|
||||
) {
|
||||
let mut env = Env {
|
||||
arena,
|
||||
interner,
|
||||
home,
|
||||
ident_ids,
|
||||
};
|
||||
|
||||
let env = &mut env;
|
||||
|
||||
for (_, proc) in procs {
|
||||
use self::SelfRecursive::*;
|
||||
if let SelfRecursive(id) = proc.is_self_recursive {
|
||||
if crate::tail_recursion::is_trmc_candidate(env.interner, proc) {
|
||||
let new_proc = crate::tail_recursion::TrmcEnv::init(env, proc);
|
||||
*proc = new_proc;
|
||||
} else {
|
||||
let mut args = Vec::with_capacity_in(proc.args.len(), arena);
|
||||
let mut proc_args = Vec::with_capacity_in(proc.args.len(), arena);
|
||||
|
||||
for (layout, symbol) in proc.args {
|
||||
let new = env.unique_symbol();
|
||||
args.push((*layout, *symbol, new));
|
||||
proc_args.push((*layout, new));
|
||||
}
|
||||
|
||||
let transformed = crate::tail_recursion::make_tail_recursive(
|
||||
arena,
|
||||
id,
|
||||
proc.name,
|
||||
proc.body.clone(),
|
||||
args.into_bump_slice(),
|
||||
proc.ret_layout,
|
||||
);
|
||||
|
||||
if let Some(with_tco) = transformed {
|
||||
proc.body = with_tco;
|
||||
proc.args = proc_args.into_bump_slice();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Make tail calls into loops (using join points)
|
||||
///
|
||||
|
@ -325,7 +402,7 @@ fn insert_jumps<'a>(
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_trmc_candidate<'a, I>(interner: &I, proc: &Proc<'a>) -> bool
|
||||
pub(crate) fn is_trmc_candidate<'a, I>(interner: &'_ I, proc: &'_ Proc<'a>) -> bool
|
||||
where
|
||||
I: LayoutInterner<'a>,
|
||||
{
|
||||
|
@ -338,10 +415,68 @@ where
|
|||
}
|
||||
|
||||
// and return a recursive tag union
|
||||
match interner.get_repr(proc.ret_layout) {
|
||||
LayoutRepr::Union(union_layout) => union_layout.is_recursive(),
|
||||
_ => false,
|
||||
if !matches!(interner.get_repr(proc.ret_layout), LayoutRepr::Union(union_layout) if union_layout.is_recursive())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
has_cons_in_tail_position(&proc.body, proc.name)
|
||||
}
|
||||
|
||||
fn has_cons_in_tail_position(initial_stmt: &Stmt<'_>, function_name: LambdaName) -> bool {
|
||||
// we are looking for code of the form
|
||||
//
|
||||
// let x = Tag a b c
|
||||
// ret x
|
||||
|
||||
let mut stack = vec![(None, initial_stmt)];
|
||||
|
||||
while let Some((recursive_call, stmt)) = stack.pop() {
|
||||
match stmt {
|
||||
Stmt::Let(symbol, expr, _, next) => {
|
||||
if let Some(cons_info) = TrmcEnv::is_terminal_constructor(stmt) {
|
||||
// must use the result of a recursive call directly as an argument
|
||||
if let Some(recursive_call) = recursive_call {
|
||||
if cons_info.arguments.contains(&recursive_call) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let recursive_call = recursive_call
|
||||
.or_else(|| TrmcEnv::is_recursive_expr(expr, function_name).map(|_| *symbol));
|
||||
|
||||
stack.push((recursive_call, next));
|
||||
}
|
||||
Stmt::Switch {
|
||||
branches,
|
||||
default_branch,
|
||||
..
|
||||
} => {
|
||||
for (_, _, stmt) in branches.iter() {
|
||||
stack.push((recursive_call, stmt));
|
||||
}
|
||||
stack.push((recursive_call, default_branch.1));
|
||||
}
|
||||
Stmt::Refcounting(_, next) => {
|
||||
stack.push((recursive_call, next));
|
||||
}
|
||||
Stmt::Expect { remainder, .. }
|
||||
| Stmt::ExpectFx { remainder, .. }
|
||||
| Stmt::Dbg { remainder, .. } => {
|
||||
stack.push((recursive_call, remainder));
|
||||
}
|
||||
Stmt::Join {
|
||||
body, remainder, ..
|
||||
} => {
|
||||
stack.push((recursive_call, body));
|
||||
stack.push((recursive_call, remainder));
|
||||
}
|
||||
Stmt::Ret(_) | Stmt::Jump(_, _) | Stmt::Crash(_, _) => { /* terminal */ }
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
|
@ -358,6 +493,7 @@ pub(crate) struct TrmcEnv<'a> {
|
|||
recursive_call: Option<(Symbol, Call<'a>)>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ConstructorInfo<'a> {
|
||||
tag_layout: UnionLayout<'a>,
|
||||
tag_id: TagIdIntType,
|
||||
|
@ -365,25 +501,18 @@ struct ConstructorInfo<'a> {
|
|||
}
|
||||
|
||||
impl<'a> TrmcEnv<'a> {
|
||||
fn is_recursive_expr(&mut self, expr: &Expr<'a>) -> Option<Call<'a>> {
|
||||
if let Expr::Call(call) = expr {
|
||||
self.is_recursive_call(call).then_some(call.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn is_terminal_constructor(&mut self, stmt: &Stmt<'a>) -> Option<ConstructorInfo<'a>> {
|
||||
#[inline(always)]
|
||||
fn is_terminal_constructor(stmt: &Stmt<'a>) -> Option<ConstructorInfo<'a>> {
|
||||
match stmt {
|
||||
Stmt::Let(s1, expr, _layout, Stmt::Ret(s2)) if s1 == s2 => {
|
||||
self.get_contructor_info(expr)
|
||||
Self::get_contructor_info(expr)
|
||||
}
|
||||
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_contructor_info(&mut self, expr: &Expr<'a>) -> Option<ConstructorInfo<'a>> {
|
||||
fn get_contructor_info(expr: &Expr<'a>) -> Option<ConstructorInfo<'a>> {
|
||||
if let Expr::Tag {
|
||||
tag_layout,
|
||||
tag_id,
|
||||
|
@ -402,16 +531,19 @@ impl<'a> TrmcEnv<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
fn is_recursive_call(&mut self, call: &Call<'a>) -> bool {
|
||||
fn is_recursive_expr(expr: &Expr<'a>, lambda_name: LambdaName<'_>) -> Option<Call<'a>> {
|
||||
if let Expr::Call(call) = expr {
|
||||
Self::is_recursive_call(call, lambda_name).then_some(call.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn is_recursive_call(call: &Call<'a>, lambda_name: LambdaName<'_>) -> bool {
|
||||
match call.call_type {
|
||||
CallType::ByName {
|
||||
name,
|
||||
ret_layout,
|
||||
arg_layouts,
|
||||
specialization_id,
|
||||
} => {
|
||||
CallType::ByName { name, .. } => {
|
||||
// TODO are there other restrictions?
|
||||
name == self.function_name
|
||||
name == lambda_name
|
||||
}
|
||||
CallType::Foreign { .. } | CallType::LowLevel { .. } | CallType::HigherOrder(_) => {
|
||||
false
|
||||
|
@ -421,16 +553,16 @@ impl<'a> TrmcEnv<'a> {
|
|||
|
||||
fn ptr_write(
|
||||
env: &mut Env<'a, '_>,
|
||||
interner: &mut impl LayoutInterner<'a>,
|
||||
return_layout: InLayout<'a>,
|
||||
_return_layout: InLayout<'a>,
|
||||
ptr: Symbol,
|
||||
value: Symbol,
|
||||
next: &'a Stmt<'a>,
|
||||
) -> Stmt<'a> {
|
||||
let box_write = Call {
|
||||
call_type: crate::ir::CallType::LowLevel {
|
||||
op: roc_module::low_level::LowLevel::PtrWrite,
|
||||
update_mode: env.next_update_mode_id(),
|
||||
op: LowLevel::PtrStore,
|
||||
// update_mode: env.next_update_mode_id(),
|
||||
update_mode: UpdateModeId::BACKEND_DUMMY,
|
||||
},
|
||||
arguments: env.arena.alloc([ptr, value]),
|
||||
};
|
||||
|
@ -438,16 +570,13 @@ impl<'a> TrmcEnv<'a> {
|
|||
Stmt::Let(
|
||||
env.named_unique_symbol("_ptr_write_unit"),
|
||||
Expr::Call(box_write),
|
||||
interner.insert_direct_no_semantic(LayoutRepr::Boxed(return_layout)),
|
||||
// interner.insert_direct_no_semantic(LayoutRepr::Boxed(return_layout)),
|
||||
Layout::UNIT,
|
||||
next,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn init(
|
||||
env: &mut Env<'a, '_>,
|
||||
interner: &mut impl LayoutInterner<'a>,
|
||||
proc: &Proc<'a>,
|
||||
) -> Proc<'a> {
|
||||
pub fn init<'i>(env: &mut Env<'a, 'i>, proc: &Proc<'a>) -> Proc<'a> {
|
||||
let arena = env.arena;
|
||||
let return_layout = proc.ret_layout;
|
||||
|
||||
|
@ -475,8 +604,9 @@ impl<'a> TrmcEnv<'a> {
|
|||
let null_symbol = env.named_unique_symbol("null");
|
||||
let let_null = |next| Stmt::Let(null_symbol, Expr::NullPointer, return_layout, next);
|
||||
|
||||
let box_return_layout =
|
||||
interner.insert_direct_no_semantic(LayoutRepr::Boxed(return_layout));
|
||||
let box_return_layout = env
|
||||
.interner
|
||||
.insert_direct_no_semantic(LayoutRepr::Boxed(return_layout));
|
||||
let box_null = Expr::ExprBox {
|
||||
symbol: null_symbol,
|
||||
};
|
||||
|
@ -508,7 +638,7 @@ impl<'a> TrmcEnv<'a> {
|
|||
let joinpoint = Stmt::Join {
|
||||
id: joinpoint_id,
|
||||
parameters: joinpoint_parameters.into_bump_slice(),
|
||||
body: arena.alloc(this.walk_stmt(env, interner, &proc.body)),
|
||||
body: arena.alloc(this.walk_stmt(env, &proc.body)),
|
||||
remainder: arena.alloc(jump_stmt),
|
||||
};
|
||||
|
||||
|
@ -534,24 +664,19 @@ impl<'a> TrmcEnv<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
fn walk_stmt(
|
||||
&mut self,
|
||||
env: &mut Env<'a, '_>,
|
||||
interner: &mut impl LayoutInterner<'a>,
|
||||
stmt: &Stmt<'a>,
|
||||
) -> Stmt<'a> {
|
||||
fn walk_stmt(&mut self, env: &mut Env<'a, '_>, stmt: &Stmt<'a>) -> Stmt<'a> {
|
||||
let arena = env.arena;
|
||||
|
||||
match stmt {
|
||||
Stmt::Let(symbol, expr, layout, next) => {
|
||||
if self.recursive_call.is_none() {
|
||||
if let Some(call) = self.is_recursive_expr(expr) {
|
||||
if let Some(call) = Self::is_recursive_expr(expr, self.function_name) {
|
||||
self.recursive_call = Some((*symbol, call));
|
||||
return self.walk_stmt(env, interner, next);
|
||||
return self.walk_stmt(env, next);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(cons_info) = self.is_terminal_constructor(stmt) {
|
||||
if let Some(cons_info) = Self::is_terminal_constructor(stmt) {
|
||||
match &self.recursive_call {
|
||||
None => {
|
||||
// this control flow path did not encounter a recursive call. Just
|
||||
|
@ -561,7 +686,7 @@ impl<'a> TrmcEnv<'a> {
|
|||
|
||||
let output = define_tag(arena.alloc(
|
||||
//
|
||||
self.non_trmc_return(env, interner, *symbol),
|
||||
self.non_trmc_return(env, *symbol),
|
||||
));
|
||||
|
||||
return output;
|
||||
|
@ -571,11 +696,21 @@ impl<'a> TrmcEnv<'a> {
|
|||
// branch.
|
||||
|
||||
// TODO remove unwrap. also what if the symbol occurs more than once?
|
||||
let recursive_field_index = cons_info
|
||||
.arguments
|
||||
.iter()
|
||||
.position(|s| *s == *call_symbol)
|
||||
.unwrap();
|
||||
let opt_recursive_field_index =
|
||||
cons_info.arguments.iter().position(|s| *s == *call_symbol);
|
||||
|
||||
let recursive_field_index = match opt_recursive_field_index {
|
||||
None => {
|
||||
let next = self.walk_stmt(env, next);
|
||||
return Stmt::Let(
|
||||
*symbol,
|
||||
expr.clone(),
|
||||
*layout,
|
||||
arena.alloc(next),
|
||||
);
|
||||
}
|
||||
Some(v) => v,
|
||||
};
|
||||
|
||||
let mut arguments =
|
||||
Vec::from_iter_in(cons_info.arguments.iter().copied(), env.arena);
|
||||
|
@ -589,8 +724,11 @@ impl<'a> TrmcEnv<'a> {
|
|||
|
||||
let let_tag = |next| Stmt::Let(*symbol, tag_expr, *layout, next);
|
||||
|
||||
let get_reference_expr = Expr::ExprBox {
|
||||
symbol: self.null_symbol,
|
||||
let get_reference_expr = Expr::UnionFieldPtrAtIndex {
|
||||
structure: *symbol,
|
||||
tag_id: cons_info.tag_id,
|
||||
union_layout: cons_info.tag_layout,
|
||||
index: recursive_field_index as _,
|
||||
};
|
||||
|
||||
let new_hole_symbol = env.named_unique_symbol("newHole");
|
||||
|
@ -616,7 +754,6 @@ impl<'a> TrmcEnv<'a> {
|
|||
//
|
||||
Self::ptr_write(
|
||||
env,
|
||||
interner,
|
||||
*layout,
|
||||
self.hole_symbol,
|
||||
*symbol,
|
||||
|
@ -630,7 +767,7 @@ impl<'a> TrmcEnv<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
let next = self.walk_stmt(env, interner, next);
|
||||
let next = self.walk_stmt(env, next);
|
||||
Stmt::Let(*symbol, expr.clone(), *layout, arena.alloc(next))
|
||||
}
|
||||
Stmt::Switch {
|
||||
|
@ -646,14 +783,13 @@ impl<'a> TrmcEnv<'a> {
|
|||
|
||||
for (id, info, stmt) in branches.iter() {
|
||||
self.recursive_call = opt_recursive_call.clone();
|
||||
let new_stmt = self.walk_stmt(env, interner, stmt);
|
||||
let new_stmt = self.walk_stmt(env, stmt);
|
||||
|
||||
new_branches.push((*id, info.clone(), new_stmt));
|
||||
}
|
||||
|
||||
self.recursive_call = opt_recursive_call;
|
||||
let new_default_branch =
|
||||
&*arena.alloc(self.walk_stmt(env, interner, default_branch.1));
|
||||
let new_default_branch = &*arena.alloc(self.walk_stmt(env, default_branch.1));
|
||||
|
||||
Stmt::Switch {
|
||||
cond_symbol: *cond_symbol,
|
||||
|
@ -666,42 +802,92 @@ impl<'a> TrmcEnv<'a> {
|
|||
Stmt::Ret(symbol) => {
|
||||
// write the symbol we're supposed to return into the hole
|
||||
// then read initial_symbol and return its contents
|
||||
self.non_trmc_return(env, interner, *symbol)
|
||||
self.non_trmc_return(env, *symbol)
|
||||
}
|
||||
Stmt::Refcounting(_, _) => todo!(),
|
||||
Stmt::Expect { .. } => todo!(),
|
||||
Stmt::ExpectFx { .. } => todo!(),
|
||||
Stmt::Dbg { .. } => todo!(),
|
||||
Stmt::Join { .. } => todo!(),
|
||||
Stmt::Jump(_, _) => todo!(),
|
||||
Stmt::Crash(_, _) => todo!(),
|
||||
Stmt::Refcounting(op, next) => {
|
||||
let new_next = self.walk_stmt(env, next);
|
||||
Stmt::Refcounting(*op, arena.alloc(new_next))
|
||||
}
|
||||
Stmt::Expect {
|
||||
condition,
|
||||
region,
|
||||
lookups,
|
||||
variables,
|
||||
remainder,
|
||||
} => Stmt::Expect {
|
||||
condition: *condition,
|
||||
region: *region,
|
||||
lookups,
|
||||
variables,
|
||||
remainder: arena.alloc(self.walk_stmt(env, remainder)),
|
||||
},
|
||||
Stmt::ExpectFx {
|
||||
condition,
|
||||
region,
|
||||
lookups,
|
||||
variables,
|
||||
remainder,
|
||||
} => Stmt::Expect {
|
||||
condition: *condition,
|
||||
region: *region,
|
||||
lookups,
|
||||
variables,
|
||||
remainder: arena.alloc(self.walk_stmt(env, remainder)),
|
||||
},
|
||||
Stmt::Dbg {
|
||||
symbol,
|
||||
variable,
|
||||
remainder,
|
||||
} => Stmt::Dbg {
|
||||
symbol: *symbol,
|
||||
variable: *variable,
|
||||
remainder: arena.alloc(self.walk_stmt(env, remainder)),
|
||||
},
|
||||
Stmt::Join {
|
||||
id,
|
||||
parameters,
|
||||
body,
|
||||
remainder,
|
||||
} => {
|
||||
let new_body = self.walk_stmt(env, body);
|
||||
let new_remainder = self.walk_stmt(env, remainder);
|
||||
|
||||
Stmt::Join {
|
||||
id: *id,
|
||||
parameters,
|
||||
body: arena.alloc(new_body),
|
||||
remainder: arena.alloc(new_remainder),
|
||||
}
|
||||
}
|
||||
Stmt::Jump(id, arguments) => Stmt::Jump(*id, arguments),
|
||||
Stmt::Crash(symbol, crash_tag) => Stmt::Crash(*symbol, *crash_tag),
|
||||
}
|
||||
}
|
||||
|
||||
fn non_trmc_return(
|
||||
&mut self,
|
||||
env: &mut Env<'a, '_>,
|
||||
interner: &mut impl LayoutInterner<'a>,
|
||||
value_symbol: Symbol,
|
||||
) -> Stmt<'a> {
|
||||
fn non_trmc_return(&mut self, env: &mut Env<'a, '_>, value_symbol: Symbol) -> Stmt<'a> {
|
||||
let arena = env.arena;
|
||||
let layout = self.return_layout;
|
||||
|
||||
let unbox_expr = Expr::ExprUnbox {
|
||||
symbol: self.initial_box_symbol,
|
||||
};
|
||||
let final_symbol = env.named_unique_symbol("final");
|
||||
let unbox = |next| Stmt::Let(final_symbol, unbox_expr, layout, next);
|
||||
|
||||
let call = Call {
|
||||
call_type: CallType::LowLevel {
|
||||
op: LowLevel::PtrLoad,
|
||||
update_mode: UpdateModeId::BACKEND_DUMMY,
|
||||
},
|
||||
arguments: &*arena.alloc([self.initial_box_symbol]),
|
||||
};
|
||||
|
||||
let ptr_load = |next| Stmt::Let(final_symbol, Expr::Call(call), layout, next);
|
||||
|
||||
Self::ptr_write(
|
||||
env,
|
||||
interner,
|
||||
layout,
|
||||
self.hole_symbol,
|
||||
value_symbol,
|
||||
arena.alloc(
|
||||
//
|
||||
unbox(arena.alloc(Stmt::Ret(final_symbol))),
|
||||
ptr_load(arena.alloc(Stmt::Ret(final_symbol))),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue