mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-27 05:49:08 +00:00
Start drop specialisation for joinpoints
This commit is contained in:
parent
2faa0e4c5b
commit
94fb89bde4
12 changed files with 365 additions and 190 deletions
|
@ -26,7 +26,7 @@ use crate::layout::{
|
|||
|
||||
use bumpalo::Bump;
|
||||
|
||||
use roc_collections::MutMap;
|
||||
use roc_collections::{MutMap, MutSet};
|
||||
|
||||
/**
|
||||
Try to find increments of symbols followed by decrements of the symbol they were indexed out of (their parent).
|
||||
|
@ -375,9 +375,39 @@ fn specialize_drops_stmt<'a, 'i>(
|
|||
}};
|
||||
}
|
||||
|
||||
environment.jump_incremented_symbols =
|
||||
new_default_branch.2.jump_incremented_symbols.clone();
|
||||
|
||||
let newer_branches = new_branches
|
||||
.iter()
|
||||
.map(|(label, info, branch, branch_env)| {
|
||||
for (joinpoint, current_incremented_symbols) in
|
||||
environment.jump_incremented_symbols.iter_mut()
|
||||
{
|
||||
match branch_env.jump_incremented_symbols.get(joinpoint) {
|
||||
Some(branch_incremented_symbols) => {
|
||||
let mut to_remove = MutSet::default();
|
||||
for (key, join_count) in current_incremented_symbols.map.iter_mut()
|
||||
{
|
||||
match branch_incremented_symbols.map.get(key) {
|
||||
Some(count) => {
|
||||
*join_count = std::cmp::min(*join_count, *count);
|
||||
}
|
||||
None => {
|
||||
to_remove.insert(*key);
|
||||
}
|
||||
}
|
||||
}
|
||||
for key in to_remove.iter() {
|
||||
current_incremented_symbols.map.remove(key);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
// Do nothing
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let new_branch = insert_incs!(branch_env, branch);
|
||||
|
||||
(*label, info.clone(), new_branch.clone())
|
||||
|
@ -624,21 +654,78 @@ fn specialize_drops_stmt<'a, 'i>(
|
|||
body,
|
||||
remainder,
|
||||
} => {
|
||||
let mut new_environment = environment.clone();
|
||||
new_environment.incremented_symbols.clear();
|
||||
let mut remainder_enviroment = environment.clone();
|
||||
|
||||
let _new_remainder = specialize_drops_stmt(
|
||||
arena,
|
||||
layout_interner,
|
||||
ident_ids,
|
||||
&mut remainder_enviroment,
|
||||
remainder,
|
||||
);
|
||||
|
||||
let remainder_jump_incremented_symbols = remainder_enviroment
|
||||
.jump_incremented_symbols
|
||||
.get(id)
|
||||
.map_or_else(|| CountingMap::new(), |map| map.clone());
|
||||
|
||||
let mut body_environment = environment.clone();
|
||||
for param in parameters.iter() {
|
||||
new_environment.add_symbol_layout(param.symbol, param.layout);
|
||||
body_environment.add_symbol_layout(param.symbol, param.layout);
|
||||
}
|
||||
|
||||
let mut body_jump_incremented_symbols = remainder_jump_incremented_symbols;
|
||||
|
||||
// Perform iteration to get the incremented_symbols for the body.
|
||||
let joinpoint_usage = loop {
|
||||
// Update the incremented_symbols to the remainder's incremented_symbols.
|
||||
let mut current_body_environment = body_environment.clone();
|
||||
current_body_environment.incremented_symbols =
|
||||
body_jump_incremented_symbols.clone();
|
||||
current_body_environment
|
||||
.jump_incremented_symbols
|
||||
.insert(*id, body_jump_incremented_symbols.clone());
|
||||
|
||||
let _new_body = specialize_drops_stmt(
|
||||
arena,
|
||||
layout_interner,
|
||||
ident_ids,
|
||||
&mut current_body_environment,
|
||||
body,
|
||||
);
|
||||
|
||||
let new_body_jump_incremented_symbols = current_body_environment
|
||||
.jump_incremented_symbols
|
||||
.get(id)
|
||||
.expect("Jump incremented symbols should be present.")
|
||||
.clone();
|
||||
|
||||
if body_jump_incremented_symbols == new_body_jump_incremented_symbols {
|
||||
break new_body_jump_incremented_symbols;
|
||||
} else {
|
||||
body_jump_incremented_symbols = new_body_jump_incremented_symbols;
|
||||
}
|
||||
};
|
||||
|
||||
let join_joinpoint_usage = arena.alloc(joinpoint_usage.clone());
|
||||
|
||||
body_environment.incremented_symbols = joinpoint_usage;
|
||||
body_environment
|
||||
.join_incremented_symbols
|
||||
.insert(*id, join_joinpoint_usage);
|
||||
|
||||
let new_body = specialize_drops_stmt(
|
||||
arena,
|
||||
layout_interner,
|
||||
ident_ids,
|
||||
&mut new_environment,
|
||||
&mut body_environment,
|
||||
body,
|
||||
);
|
||||
|
||||
environment
|
||||
.join_incremented_symbols
|
||||
.insert(*id, join_joinpoint_usage);
|
||||
|
||||
arena.alloc(Stmt::Join {
|
||||
id: *id,
|
||||
parameters,
|
||||
|
@ -652,7 +739,29 @@ fn specialize_drops_stmt<'a, 'i>(
|
|||
),
|
||||
})
|
||||
}
|
||||
Stmt::Jump(joinpoint_id, arguments) => arena.alloc(Stmt::Jump(*joinpoint_id, arguments)),
|
||||
Stmt::Jump(joinpoint_id, arguments) => {
|
||||
match environment.join_incremented_symbols.get(joinpoint_id) {
|
||||
Some(join_usage) => {
|
||||
// Consume all symbols that were consumed in the join.
|
||||
for (symbol, count) in join_usage.map.iter() {
|
||||
for _ in 0..*count {
|
||||
let popped = environment.incremented_symbols.pop(symbol);
|
||||
debug_assert!(
|
||||
popped,
|
||||
"Every incremented symbol should be available from jumps"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
// No join usage, let the join know the minimum amount of symbols that were incremented from each jump.
|
||||
environment
|
||||
.jump_incremented_symbols
|
||||
.insert(*joinpoint_id, environment.incremented_symbols.clone());
|
||||
}
|
||||
}
|
||||
arena.alloc(Stmt::Jump(*joinpoint_id, arguments))
|
||||
}
|
||||
Stmt::Crash(symbol, crash_tag) => arena.alloc(Stmt::Crash(*symbol, *crash_tag)),
|
||||
}
|
||||
}
|
||||
|
@ -1297,6 +1406,12 @@ struct DropSpecializationEnvironment<'a> {
|
|||
|
||||
// Map containing the current known length of a list.
|
||||
list_length: MutMap<Symbol, u64>,
|
||||
|
||||
// A map containing the minimum number of symbol increments from jumps for a joinpoint.
|
||||
jump_incremented_symbols: MutMap<JoinPointId, CountingMap<Symbol>>,
|
||||
|
||||
// A map containing the expected number of symbol increments from joinpoints for a jump.
|
||||
join_incremented_symbols: MutMap<JoinPointId, &'a CountingMap<Symbol>>,
|
||||
}
|
||||
|
||||
impl<'a> DropSpecializationEnvironment<'a> {
|
||||
|
@ -1314,6 +1429,8 @@ impl<'a> DropSpecializationEnvironment<'a> {
|
|||
symbol_tag: MutMap::default(),
|
||||
symbol_index: MutMap::default(),
|
||||
list_length: MutMap::default(),
|
||||
jump_incremented_symbols: MutMap::default(),
|
||||
join_incremented_symbols: MutMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1501,8 +1618,8 @@ fn low_level_no_rc(lowlevel: &LowLevel) -> RC {
|
|||
|
||||
/// Map that contains a count for each key.
|
||||
/// Keys with a count of 0 are kept around, so that it can be seen that they were once present.
|
||||
#[derive(Clone)]
|
||||
struct CountingMap<K> {
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
struct CountingMap<K: std::cmp::Eq + std::hash::Hash> {
|
||||
map: MutMap<K, u64>,
|
||||
}
|
||||
|
||||
|
@ -1549,8 +1666,4 @@ where
|
|||
}
|
||||
res
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.map.clear();
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue