This commit is contained in:
J.Teeuwissen 2023-04-29 15:28:25 +02:00
parent c1ced3c5d2
commit dbab89cc64
No known key found for this signature in database
GPG key ID: DB5F7A1ED8D478AD
4 changed files with 333 additions and 52 deletions

View file

@ -13,14 +13,15 @@ use std::iter::Iterator;
use bumpalo::collections::vec::Vec;
use bumpalo::collections::CollectIn;
use roc_module::low_level::LowLevel;
use roc_module::low_level::{LowLevel, LowLevelWrapperType};
use roc_module::symbol::{IdentIds, ModuleId, Symbol};
use roc_target::TargetInfo;
use crate::ir::{
BranchInfo, Call, CallType, Expr, JoinPointId, ModifyRc, Proc, ProcLayout, Stmt, UpdateModeId,
BranchInfo, Call, CallType, Expr, JoinPointId, Literal, ModifyRc, Proc, ProcLayout, Stmt,
UpdateModeId,
};
use crate::layout::{InLayout, Layout, LayoutInterner, STLayoutInterner, UnionLayout};
use crate::layout::{Builtin, InLayout, Layout, LayoutInterner, STLayoutInterner, UnionLayout};
use bumpalo::Bump;
@ -87,16 +88,46 @@ fn specialize_drops_stmt<'a, 'i>(
}
match expr {
Expr::Call(_) | Expr::Struct(_) => {
// TODO perhaps allow for some e.g. lowlevel functions to be called if they cannot modify the RC of the symbol.
Expr::Call(Call {
call_type,
arguments,
}) => {
match call_type {
CallType::ByName { name, .. }
if (LowLevelWrapperType::from_symbol(name.name())
== LowLevelWrapperType::CanBeReplacedBy(
LowLevel::ListGetUnsafe,
)) =>
{
environment.add_list_child(arguments[0], *binding, &arguments[1]);
// Calls can modify the RC of the symbol.
// If we move a increment of children after the function,
// the function might deallocate the child before we can use it after the function.
// If we move the decrement of the parent to before the function,
// the parent might be deallocated before the function can use it.
// Thus forget everything about any increments.
alloc_let_with_continuation!(environment)
}
CallType::LowLevel {
op: LowLevel::ListGetUnsafe,
..
} => {
environment.add_list_child(arguments[0], *binding, &arguments[1]);
alloc_let_with_continuation!(environment)
}
_ => {
// TODO perhaps allow for some e.g. lowlevel functions to be called if they cannot modify the RC of the symbol.
// Calls can modify the RC of the symbol.
// If we move a increment of children after the function,
// the function might deallocate the child before we can use it after the function.
// If we move the decrement of the parent to before the function,
// the parent might be deallocated before the function can use it.
// Thus forget everything about any increments.
let mut new_environment = environment.clone_without_incremented();
alloc_let_with_continuation!(&mut new_environment)
}
}
}
Expr::Struct(_) => {
let mut new_environment = environment.clone_without_incremented();
alloc_let_with_continuation!(&mut new_environment)
@ -147,10 +178,20 @@ fn specialize_drops_stmt<'a, 'i>(
Expr::ResetRef { .. } => {
alloc_let_with_continuation!(environment)
}
Expr::Literal(literal) => {
// literal ints are used to store the the index for lists.
// Add it to the env so when we use it to index a list, we can use the index to specialize the drop.
if let Literal::Int(i) = literal {
environment
.symbol_index
.insert(*binding, i128::from_ne_bytes(*i) as u64);
}
alloc_let_with_continuation!(environment)
}
Expr::RuntimeErrorFunction(_)
| Expr::ExprBox { .. }
| Expr::NullPointer
| Expr::Literal(_)
| Expr::GetTagId { .. }
| Expr::EmptyArray
| Expr::Array { .. } => {
@ -166,19 +207,33 @@ fn specialize_drops_stmt<'a, 'i>(
default_branch,
ret_layout,
} => {
macro_rules! insert_branch_info {
($branch_env:expr,$info:expr ) => {
match $info {
BranchInfo::Constructor {
scrutinee: symbol,
tag_id: tag,
..
} => {
$branch_env.symbol_tag.insert(*symbol, *tag);
}
BranchInfo::List {
scrutinee: symbol,
len,
} => {
$branch_env.list_length.insert(*symbol, *len);
}
_ => (),
}
};
}
let new_branches = branches
.iter()
.map(|(label, info, branch)| {
let mut branch_env = environment.clone_without_incremented();
if let BranchInfo::Constructor {
scrutinee: symbol,
tag_id: tag,
..
} = info
{
branch_env.symbol_tag.insert(*symbol, *tag);
}
insert_branch_info!(branch_env, info);
let new_branch = specialize_drops_stmt(
arena,
@ -198,14 +253,7 @@ fn specialize_drops_stmt<'a, 'i>(
let mut branch_env = environment.clone_without_incremented();
if let BranchInfo::Constructor {
scrutinee: symbol,
tag_id: tag,
..
} = info
{
branch_env.symbol_tag.insert(*symbol, *tag);
}
insert_branch_info!(branch_env, info);
let new_branch = specialize_drops_stmt(
arena,
@ -326,8 +374,17 @@ fn specialize_drops_stmt<'a, 'i>(
symbol,
continuation,
),
Layout::Builtin(Builtin::List(layout)) => specialize_list(
arena,
layout_interner,
ident_ids,
environment,
&mut incremented_children,
symbol,
layout,
continuation,
),
// TODO: lambda sets should not be reachable, yet they are.
// TODO: Implement this with uniqueness checks.
_ => {
let new_continuation = specialize_drops_stmt(
arena,
@ -782,6 +839,160 @@ fn specialize_boxed<'a, 'i>(
}
}
fn specialize_list<'a, 'i>(
arena: &'a Bump,
layout_interner: &'i mut STLayoutInterner<'a>,
ident_ids: &'i mut IdentIds,
environment: &mut DropSpecializationEnvironment<'a>,
incremented_children: &mut MutSet<Child>,
symbol: &Symbol,
item_layout: InLayout,
continuation: &'a Stmt<'a>,
) -> &'a Stmt<'a> {
let current_length = environment.list_length.get(symbol).copied();
macro_rules! keep_original_decrement {
() => {{
let new_continuation =
specialize_drops_stmt(arena, layout_interner, ident_ids, environment, continuation);
arena.alloc(Stmt::Refcounting(ModifyRc::Dec(*symbol), new_continuation))
}};
}
match (
layout_interner.contains_refcounted(item_layout),
current_length,
) {
(true, Some(length)) => {
match environment.list_children.get(symbol) {
Some(children) if children.len() as u64 == length => {
// Only specialize lists if all children are known.
// Otherwise we might have to insert an unbouned number of decrements.
// TODO perhaps this allocation can be avoided.
let children_clone = children.clone();
// Map tracking which index of the struct is contained in which symbol.
// And whether the child no longer has to be decremented.
let mut index_symbols = MutMap::default();
for index in 0..length {
for (child, i) in children_clone
.iter()
.filter(|(_child, i)| *i == index as u64)
{
debug_assert!(length > *i);
let removed = incremented_children.remove(child);
index_symbols.insert(index, (*child, removed));
if removed {
break;
}
}
}
let new_continuation = specialize_drops_stmt(
arena,
layout_interner,
ident_ids,
environment,
continuation,
);
let refcount_items = |rc_popped: Option<
fn(arena: &'a Bump, Symbol, &'a Stmt<'a>) -> &'a Stmt<'a>,
>,
rc_unpopped: Option<
fn(arena: &'a Bump, Symbol, &'a Stmt<'a>) -> &'a Stmt<'a>,
>,
continuation: &'a Stmt<'a>|
-> &'a Stmt<'a> {
let mut new_continuation = continuation;
// Reversed to ensure that the generated code decrements the items in the correct order.
for i in (0..length).rev() {
let (s, popped) = index_symbols.get(&i).unwrap();
new_continuation = {
if *popped {
// This symbol was popped, so we can skip the decrement.
match rc_popped {
Some(rc) => rc(arena, *s, new_continuation),
None => new_continuation,
}
} else {
// This symbol was indexed but not decremented, so we will decrement it.
match rc_unpopped {
Some(rc) => rc(arena, *s, new_continuation),
None => new_continuation,
}
}
};
}
new_continuation
};
branch_uniqueness(
arena,
ident_ids,
layout_interner,
environment,
*symbol,
// If the symbol is unique:
// - drop the children that were not incremented before
// - don't do anything for the children that were incremented before
// - free the parent
|_layout_interner, _ident_ids, continuation| {
refcount_items(
// Do nothing for the children that were incremented before, as the decrement will cancel out.
None,
// Decrement the children that were not incremented before. And thus don't cancel out.
Some(|arena, symbol, continuation| {
arena.alloc(Stmt::Refcounting(
ModifyRc::Dec(symbol),
continuation,
))
}),
arena.alloc(Stmt::Refcounting(
// TODO this could be replaced by a free if ever added to the IR.
ModifyRc::DecRef(*symbol),
continuation,
)),
)
},
// If the symbol is not unique:
// - increment the children that were incremented before
// - don't do anything for the children that were not incremented before
// - decref the parent
|_layout_interner, _ident_ids, continuation| {
refcount_items(
Some(|arena, symbol, continuation| {
arena.alloc(Stmt::Refcounting(
ModifyRc::Inc(symbol, 1),
continuation,
))
}),
None,
arena.alloc(Stmt::Refcounting(
ModifyRc::DecRef(*symbol),
continuation,
)),
)
},
new_continuation,
)
}
_ => keep_original_decrement!(),
}
}
_ => {
// List length is unknown or the children are not reference counted, so we can't specialize.
keep_original_decrement!()
}
}
}
/**
Get the field layouts of a union given a tag.
*/
@ -965,11 +1176,20 @@ struct DropSpecializationEnvironment<'a> {
// Keeps track of which parent symbol is indexed by which child symbol for boxes
box_children: MutMap<Parent, Vec<'a, Child>>,
// Keeps track of which parent symbol is indexed by which child symbol for lists
list_children: MutMap<Parent, Vec<'a, (Child, Index)>>,
// Keeps track of all incremented symbols.
incremented_symbols: MutMap<Symbol, u64>,
// Map containing the curren't known tag of a layout.
// Map containing the current known tag of a layout.
symbol_tag: MutMap<Symbol, Tag>,
// Map containing the current known index value of a symbol.
symbol_index: MutMap<Symbol, Index>,
// Map containing the current known length of a list.
list_length: MutMap<Symbol, u64>,
}
impl<'a> DropSpecializationEnvironment<'a> {
@ -983,8 +1203,11 @@ impl<'a> DropSpecializationEnvironment<'a> {
struct_children: MutMap::default(),
union_children: MutMap::default(),
box_children: MutMap::default(),
list_children: MutMap::default(),
incremented_symbols: MutMap::default(),
symbol_tag: MutMap::default(),
symbol_index: MutMap::default(),
list_length: MutMap::default(),
}
}
@ -998,8 +1221,11 @@ impl<'a> DropSpecializationEnvironment<'a> {
struct_children: self.struct_children.clone(),
union_children: self.union_children.clone(),
box_children: self.box_children.clone(),
list_children: self.list_children.clone(),
incremented_symbols: MutMap::default(),
symbol_tag: self.symbol_tag.clone(),
symbol_index: self.symbol_index.clone(),
list_length: self.list_length.clone(),
}
}
@ -1024,12 +1250,14 @@ impl<'a> DropSpecializationEnvironment<'a> {
.or_insert_with(|| Vec::new_in(self.arena))
.push((child, index));
}
fn add_union_child(&mut self, parent: Parent, child: Child, tag: u16, index: Index) {
self.union_children
.entry(parent)
.or_insert_with(|| Vec::new_in(self.arena))
.push((child, tag, index));
}
fn add_box_child(&mut self, parent: Parent, child: Child) {
self.box_children
.entry(parent)
@ -1037,6 +1265,20 @@ impl<'a> DropSpecializationEnvironment<'a> {
.push(child);
}
fn add_list_child(&mut self, parent: Parent, child: Child, index: &Symbol) {
match self.symbol_index.get(index) {
Some(index) => {
self.list_children
.entry(parent)
.or_insert_with(|| Vec::new_in(self.arena))
.push((child, *index));
}
None => {
// List index is not constant, so we don't know the index of the child.
}
}
}
fn get_children(&self, parent: &Parent) -> Vec<'a, Symbol> {
let mut res = Vec::new_in(self.arena);
@ -1052,6 +1294,10 @@ impl<'a> DropSpecializationEnvironment<'a> {
children.iter().for_each(|child| res.push(*child));
}
if let Some(children) = self.list_children.get(parent) {
children.iter().for_each(|(child, _)| res.push(*child));
}
res
}