using CountingMap for children

This commit is contained in:
J.Teeuwissen 2023-05-28 23:10:24 +02:00
parent 8f022d4310
commit dbebaf17a6
No known key found for this signature in database
GPG key ID: DB5F7A1ED8D478AD

View file

@ -27,7 +27,7 @@ use crate::layout::{
use bumpalo::Bump;
use roc_collections::{MutMap, MutSet};
use roc_collections::MutMap;
/**
Try to find increments of symbols followed by decrements of the symbol they were indexed out of (their parent).
@ -317,10 +317,12 @@ fn specialize_drops_stmt<'a, 'i>(
};
// Find the lowest symbol count for each symbol in each branch, and update the environment to match.
for (symbol, count) in environment.incremented_symbols.iter_mut() {
for (symbol, count) in environment.incremented_symbols.map.iter_mut() {
let consumed = branch_envs
.iter()
.map(|branch_env| branch_env.incremented_symbols.get(symbol).unwrap_or(&0))
.map(|branch_env| {
branch_env.incremented_symbols.map.get(symbol).unwrap_or(&0)
})
.min()
.unwrap();
@ -334,10 +336,14 @@ fn specialize_drops_stmt<'a, 'i>(
let symbol_differences =
environment
.incremented_symbols
.map
.iter()
.filter_map(|(symbol, count)| {
let branch_count =
$branch_env.incremented_symbols.get(symbol).unwrap_or(&0);
let branch_count = $branch_env
.incremented_symbols
.map
.get(symbol)
.unwrap_or(&0);
match branch_count - count {
0 => None,
@ -375,6 +381,7 @@ fn specialize_drops_stmt<'a, 'i>(
// Remove all 0 counts as cleanup.
environment
.incremented_symbols
.map
.retain(|_, count| *count > 0);
arena.alloc(Stmt::Switch {
@ -388,10 +395,12 @@ fn specialize_drops_stmt<'a, 'i>(
Stmt::Ret(symbol) => arena.alloc(Stmt::Ret(*symbol)),
Stmt::Refcounting(rc, continuation) => match rc {
ModifyRc::Inc(symbol, count) => {
let any = environment.any_incremented(symbol);
let any = environment.incremented_symbols.contains(symbol);
// Add a symbol for every increment performed.
environment.add_incremented(*symbol, *count);
environment
.incremented_symbols
.insert_count(*symbol, *count);
let new_continuation = specialize_drops_stmt(
arena,
@ -406,7 +415,12 @@ fn specialize_drops_stmt<'a, 'i>(
// Or there are no increments left, so we can just continue.
new_continuation
} else {
match environment.get_incremented(symbol) {
match environment
.incremented_symbols
.map
.remove(symbol)
.unwrap_or(0)
{
// This is the first increment, but all increments are consumed. So don't insert any.
0 => new_continuation,
// We still need to do some increments.
@ -427,7 +441,7 @@ fn specialize_drops_stmt<'a, 'i>(
// dec a
// dec b
if environment.pop_incremented(symbol) {
if environment.incremented_symbols.pop(symbol) {
// This decremented symbol was incremented before, so we can remove it.
specialize_drops_stmt(
arena,
@ -445,10 +459,10 @@ fn specialize_drops_stmt<'a, 'i>(
// As a might get dropped as a result of the decrement of b.
let mut incremented_children = {
let mut todo_children = bumpalo::vec![in arena; *symbol];
let mut incremented_children = MutSet::default();
let mut incremented_children = CountingMap::new();
while let Some(child) = todo_children.pop() {
if environment.pop_incremented(&child) {
if environment.incremented_symbols.pop(&child) {
incremented_children.insert(child);
} else {
todo_children.extend(environment.get_children(&child));
@ -519,8 +533,10 @@ fn specialize_drops_stmt<'a, 'i>(
};
// Add back the increments for the children to the environment.
for child_symbol in incremented_children.iter() {
environment.add_incremented(*child_symbol, 1)
for (child_symbol, symbol_count) in incremented_children.map.into_iter() {
environment
.incremented_symbols
.insert_count(child_symbol, symbol_count)
}
updated_stmt
@ -638,7 +654,7 @@ fn specialize_struct<'a, 'i>(
environment: &mut DropSpecializationEnvironment<'a>,
symbol: &Symbol,
struct_layout: &'a [InLayout],
incremented_children: &mut MutSet<Child>,
incremented_children: &mut CountingMap<Child>,
continuation: &'a Stmt<'a>,
) -> &'a Stmt<'a> {
match environment.struct_children.get(symbol) {
@ -654,7 +670,7 @@ fn specialize_struct<'a, 'i>(
for (index, _layout) in struct_layout.iter().enumerate() {
for (child, _i) in children_clone.iter().filter(|(_, i)| *i == index as u64) {
let removed = incremented_children.remove(child);
let removed = incremented_children.pop(child);
index_symbols.insert(index, (*child, removed));
if removed {
@ -727,7 +743,7 @@ fn specialize_union<'a, 'i>(
environment: &mut DropSpecializationEnvironment<'a>,
symbol: &Symbol,
union_layout: UnionLayout<'a>,
incremented_children: &mut MutSet<Child>,
incremented_children: &mut CountingMap<Child>,
continuation: &'a Stmt<'a>,
) -> &'a Stmt<'a> {
let current_tag = environment.symbol_tag.get(symbol).copied();
@ -770,7 +786,7 @@ fn specialize_union<'a, 'i>(
{
debug_assert_eq!(tag, *t);
let removed = incremented_children.remove(child);
let removed = incremented_children.pop(child);
index_symbols.insert(index, (*child, removed));
if removed {
@ -932,14 +948,14 @@ fn specialize_boxed<'a, 'i>(
layout_interner: &'i mut STLayoutInterner<'a>,
ident_ids: &'i mut IdentIds,
environment: &mut DropSpecializationEnvironment<'a>,
incremented_children: &mut MutSet<Child>,
incremented_children: &mut CountingMap<Child>,
symbol: &Symbol,
continuation: &'a Stmt<'a>,
) -> &'a Stmt<'a> {
let removed = match incremented_children.iter().next() {
Some(s) => {
let removed = match incremented_children.map.iter().next() {
Some((s, _)) => {
let s = *s;
incremented_children.remove(&s);
incremented_children.pop(&s);
Some(s)
}
None => None,
@ -989,7 +1005,7 @@ fn specialize_list<'a, 'i>(
layout_interner: &'i mut STLayoutInterner<'a>,
ident_ids: &'i mut IdentIds,
environment: &mut DropSpecializationEnvironment<'a>,
incremented_children: &mut MutSet<Child>,
incremented_children: &mut CountingMap<Child>,
symbol: &Symbol,
item_layout: InLayout,
continuation: &'a Stmt<'a>,
@ -1024,7 +1040,7 @@ fn specialize_list<'a, 'i>(
for (child, i) in children_clone.iter().filter(|(_child, i)| *i == index) {
debug_assert!(length > *i);
let removed = incremented_children.remove(child);
let removed = incremented_children.pop(child);
index_symbols.insert(index, (*child, removed));
if removed {
@ -1262,7 +1278,7 @@ struct DropSpecializationEnvironment<'a> {
list_children: MutMap<Parent, Vec<'a, (Child, Index)>>,
// Keeps track of all incremented symbols.
incremented_symbols: MutMap<Symbol, u64>,
incremented_symbols: CountingMap<Symbol>,
// Map containing the current known tag of a layout.
symbol_tag: MutMap<Symbol, Tag>,
@ -1286,7 +1302,7 @@ impl<'a> DropSpecializationEnvironment<'a> {
union_children: MutMap::default(),
box_children: MutMap::default(),
list_children: MutMap::default(),
incremented_symbols: MutMap::default(),
incremented_symbols: CountingMap::new(),
symbol_tag: MutMap::default(),
symbol_index: MutMap::default(),
list_length: MutMap::default(),
@ -1304,7 +1320,7 @@ impl<'a> DropSpecializationEnvironment<'a> {
union_children: self.union_children.clone(),
box_children: self.box_children.clone(),
list_children: self.list_children.clone(),
incremented_symbols: MutMap::default(),
incremented_symbols: CountingMap::new(),
symbol_tag: self.symbol_tag.clone(),
symbol_index: self.symbol_index.clone(),
list_length: self.list_length.clone(),
@ -1381,39 +1397,6 @@ impl<'a> DropSpecializationEnvironment<'a> {
res
}
/**
Add a symbol for every increment performed.
*/
fn add_incremented(&mut self, symbol: Symbol, count: u64) {
self.incremented_symbols
.entry(symbol)
.and_modify(|c| *c += count)
.or_insert(count);
}
fn any_incremented(&self, symbol: &Symbol) -> bool {
self.incremented_symbols.contains_key(symbol)
}
/**
Return the amount of times a symbol still has to be incremented.
Accounting for later consumtion and removal of the increment.
*/
fn get_incremented(&mut self, symbol: &Symbol) -> u64 {
self.incremented_symbols.remove(symbol).unwrap_or(0)
}
fn pop_incremented(&mut self, symbol: &Symbol) -> bool {
match self.incremented_symbols.get_mut(symbol) {
Some(0) => false,
Some(c) => {
*c -= 1;
true
}
None => false,
}
}
}
/**
@ -1525,3 +1508,51 @@ fn low_level_no_rc(lowlevel: &LowLevel) -> RC {
}
}
}
#[derive(Clone)]
struct CountingMap<K>
where
K: Eq + std::hash::Hash,
{
map: MutMap<K, u64>,
}
impl<K> CountingMap<K>
where
K: Eq + std::hash::Hash,
{
fn new() -> Self {
Self {
map: MutMap::default(),
}
}
fn insert(&mut self, key: K) {
self.map.entry(key).and_modify(|c| *c += 1).or_insert(1);
}
fn insert_count(&mut self, key: K, count: u64) {
self.map
.entry(key)
.and_modify(|c| *c += count)
.or_insert(count);
}
fn pop(&mut self, key: &K) -> bool {
match self.map.get_mut(key) {
Some(1) => {
self.map.remove(key);
true
}
Some(c) => {
*c -= 1;
true
}
None => false,
}
}
fn contains(&self, symbol: &K) -> bool {
self.map.contains_key(symbol)
}
}