Start drop specialisation for joinpoints

This commit is contained in:
J.Teeuwissen 2023-05-29 16:26:50 +02:00 committed by Folkert
parent 2faa0e4c5b
commit 94fb89bde4
No known key found for this signature in database
GPG key ID: 1F17F6FFD112B97C
12 changed files with 365 additions and 190 deletions

View file

@ -26,7 +26,7 @@ use crate::layout::{
use bumpalo::Bump;
use roc_collections::MutMap;
use roc_collections::{MutMap, MutSet};
/**
Try to find increments of symbols followed by decrements of the symbol they were indexed out of (their parent).
@ -375,9 +375,39 @@ fn specialize_drops_stmt<'a, 'i>(
}};
}
environment.jump_incremented_symbols =
new_default_branch.2.jump_incremented_symbols.clone();
let newer_branches = new_branches
.iter()
.map(|(label, info, branch, branch_env)| {
for (joinpoint, current_incremented_symbols) in
environment.jump_incremented_symbols.iter_mut()
{
match branch_env.jump_incremented_symbols.get(joinpoint) {
Some(branch_incremented_symbols) => {
let mut to_remove = MutSet::default();
for (key, join_count) in current_incremented_symbols.map.iter_mut()
{
match branch_incremented_symbols.map.get(key) {
Some(count) => {
*join_count = std::cmp::min(*join_count, *count);
}
None => {
to_remove.insert(*key);
}
}
}
for key in to_remove.iter() {
current_incremented_symbols.map.remove(key);
}
}
None => {
// Do nothing
}
}
}
let new_branch = insert_incs!(branch_env, branch);
(*label, info.clone(), new_branch.clone())
@ -624,21 +654,78 @@ fn specialize_drops_stmt<'a, 'i>(
body,
remainder,
} => {
let mut new_environment = environment.clone();
new_environment.incremented_symbols.clear();
let mut remainder_enviroment = environment.clone();
let _new_remainder = specialize_drops_stmt(
arena,
layout_interner,
ident_ids,
&mut remainder_enviroment,
remainder,
);
let remainder_jump_incremented_symbols = remainder_enviroment
.jump_incremented_symbols
.get(id)
.map_or_else(|| CountingMap::new(), |map| map.clone());
let mut body_environment = environment.clone();
for param in parameters.iter() {
new_environment.add_symbol_layout(param.symbol, param.layout);
body_environment.add_symbol_layout(param.symbol, param.layout);
}
let mut body_jump_incremented_symbols = remainder_jump_incremented_symbols;
// Perform iteration to get the incremented_symbols for the body.
let joinpoint_usage = loop {
// Update the incremented_symbols to the remainder's incremented_symbols.
let mut current_body_environment = body_environment.clone();
current_body_environment.incremented_symbols =
body_jump_incremented_symbols.clone();
current_body_environment
.jump_incremented_symbols
.insert(*id, body_jump_incremented_symbols.clone());
let _new_body = specialize_drops_stmt(
arena,
layout_interner,
ident_ids,
&mut current_body_environment,
body,
);
let new_body_jump_incremented_symbols = current_body_environment
.jump_incremented_symbols
.get(id)
.expect("Jump incremented symbols should be present.")
.clone();
if body_jump_incremented_symbols == new_body_jump_incremented_symbols {
break new_body_jump_incremented_symbols;
} else {
body_jump_incremented_symbols = new_body_jump_incremented_symbols;
}
};
let join_joinpoint_usage = arena.alloc(joinpoint_usage.clone());
body_environment.incremented_symbols = joinpoint_usage;
body_environment
.join_incremented_symbols
.insert(*id, join_joinpoint_usage);
let new_body = specialize_drops_stmt(
arena,
layout_interner,
ident_ids,
&mut new_environment,
&mut body_environment,
body,
);
environment
.join_incremented_symbols
.insert(*id, join_joinpoint_usage);
arena.alloc(Stmt::Join {
id: *id,
parameters,
@ -652,7 +739,29 @@ fn specialize_drops_stmt<'a, 'i>(
),
})
}
Stmt::Jump(joinpoint_id, arguments) => arena.alloc(Stmt::Jump(*joinpoint_id, arguments)),
Stmt::Jump(joinpoint_id, arguments) => {
match environment.join_incremented_symbols.get(joinpoint_id) {
Some(join_usage) => {
// Consume all symbols that were consumed in the join.
for (symbol, count) in join_usage.map.iter() {
for _ in 0..*count {
let popped = environment.incremented_symbols.pop(symbol);
debug_assert!(
popped,
"Every incremented symbol should be available from jumps"
);
}
}
}
None => {
// No join usage, let the join know the minimum amount of symbols that were incremented from each jump.
environment
.jump_incremented_symbols
.insert(*joinpoint_id, environment.incremented_symbols.clone());
}
}
arena.alloc(Stmt::Jump(*joinpoint_id, arguments))
}
Stmt::Crash(symbol, crash_tag) => arena.alloc(Stmt::Crash(*symbol, *crash_tag)),
}
}
@ -1297,6 +1406,12 @@ struct DropSpecializationEnvironment<'a> {
// Map containing the current known length of a list.
list_length: MutMap<Symbol, u64>,
// A map containing the minimum number of symbol increments from jumps for a joinpoint.
jump_incremented_symbols: MutMap<JoinPointId, CountingMap<Symbol>>,
// A map containing the expected number of symbol increments from joinpoints for a jump.
join_incremented_symbols: MutMap<JoinPointId, &'a CountingMap<Symbol>>,
}
impl<'a> DropSpecializationEnvironment<'a> {
@ -1314,6 +1429,8 @@ impl<'a> DropSpecializationEnvironment<'a> {
symbol_tag: MutMap::default(),
symbol_index: MutMap::default(),
list_length: MutMap::default(),
jump_incremented_symbols: MutMap::default(),
join_incremented_symbols: MutMap::default(),
}
}
@ -1501,8 +1618,8 @@ fn low_level_no_rc(lowlevel: &LowLevel) -> RC {
/// Map that contains a count for each key.
/// Keys with a count of 0 are kept around, so that it can be seen that they were once present.
#[derive(Clone)]
struct CountingMap<K> {
#[derive(Clone, PartialEq, Eq)]
struct CountingMap<K: std::cmp::Eq + std::hash::Hash> {
map: MutMap<K, u64>,
}
@ -1549,8 +1666,4 @@ where
}
res
}
fn clear(&mut self) {
self.map.clear();
}
}