mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-27 05:49:08 +00:00
using CountingMap for children
This commit is contained in:
parent
8f022d4310
commit
dbebaf17a6
1 changed files with 90 additions and 59 deletions
|
@ -27,7 +27,7 @@ use crate::layout::{
|
|||
|
||||
use bumpalo::Bump;
|
||||
|
||||
use roc_collections::{MutMap, MutSet};
|
||||
use roc_collections::MutMap;
|
||||
|
||||
/**
|
||||
Try to find increments of symbols followed by decrements of the symbol they were indexed out of (their parent).
|
||||
|
@ -317,10 +317,12 @@ fn specialize_drops_stmt<'a, 'i>(
|
|||
};
|
||||
|
||||
// Find the lowest symbol count for each symbol in each branch, and update the environment to match.
|
||||
for (symbol, count) in environment.incremented_symbols.iter_mut() {
|
||||
for (symbol, count) in environment.incremented_symbols.map.iter_mut() {
|
||||
let consumed = branch_envs
|
||||
.iter()
|
||||
.map(|branch_env| branch_env.incremented_symbols.get(symbol).unwrap_or(&0))
|
||||
.map(|branch_env| {
|
||||
branch_env.incremented_symbols.map.get(symbol).unwrap_or(&0)
|
||||
})
|
||||
.min()
|
||||
.unwrap();
|
||||
|
||||
|
@ -334,10 +336,14 @@ fn specialize_drops_stmt<'a, 'i>(
|
|||
let symbol_differences =
|
||||
environment
|
||||
.incremented_symbols
|
||||
.map
|
||||
.iter()
|
||||
.filter_map(|(symbol, count)| {
|
||||
let branch_count =
|
||||
$branch_env.incremented_symbols.get(symbol).unwrap_or(&0);
|
||||
let branch_count = $branch_env
|
||||
.incremented_symbols
|
||||
.map
|
||||
.get(symbol)
|
||||
.unwrap_or(&0);
|
||||
|
||||
match branch_count - count {
|
||||
0 => None,
|
||||
|
@ -375,6 +381,7 @@ fn specialize_drops_stmt<'a, 'i>(
|
|||
// Remove all 0 counts as cleanup.
|
||||
environment
|
||||
.incremented_symbols
|
||||
.map
|
||||
.retain(|_, count| *count > 0);
|
||||
|
||||
arena.alloc(Stmt::Switch {
|
||||
|
@ -388,10 +395,12 @@ fn specialize_drops_stmt<'a, 'i>(
|
|||
Stmt::Ret(symbol) => arena.alloc(Stmt::Ret(*symbol)),
|
||||
Stmt::Refcounting(rc, continuation) => match rc {
|
||||
ModifyRc::Inc(symbol, count) => {
|
||||
let any = environment.any_incremented(symbol);
|
||||
let any = environment.incremented_symbols.contains(symbol);
|
||||
|
||||
// Add a symbol for every increment performed.
|
||||
environment.add_incremented(*symbol, *count);
|
||||
environment
|
||||
.incremented_symbols
|
||||
.insert_count(*symbol, *count);
|
||||
|
||||
let new_continuation = specialize_drops_stmt(
|
||||
arena,
|
||||
|
@ -406,7 +415,12 @@ fn specialize_drops_stmt<'a, 'i>(
|
|||
// Or there are no increments left, so we can just continue.
|
||||
new_continuation
|
||||
} else {
|
||||
match environment.get_incremented(symbol) {
|
||||
match environment
|
||||
.incremented_symbols
|
||||
.map
|
||||
.remove(symbol)
|
||||
.unwrap_or(0)
|
||||
{
|
||||
// This is the first increment, but all increments are consumed. So don't insert any.
|
||||
0 => new_continuation,
|
||||
// We still need to do some increments.
|
||||
|
@ -427,7 +441,7 @@ fn specialize_drops_stmt<'a, 'i>(
|
|||
// dec a
|
||||
// dec b
|
||||
|
||||
if environment.pop_incremented(symbol) {
|
||||
if environment.incremented_symbols.pop(symbol) {
|
||||
// This decremented symbol was incremented before, so we can remove it.
|
||||
specialize_drops_stmt(
|
||||
arena,
|
||||
|
@ -445,10 +459,10 @@ fn specialize_drops_stmt<'a, 'i>(
|
|||
// As a might get dropped as a result of the decrement of b.
|
||||
let mut incremented_children = {
|
||||
let mut todo_children = bumpalo::vec![in arena; *symbol];
|
||||
let mut incremented_children = MutSet::default();
|
||||
let mut incremented_children = CountingMap::new();
|
||||
|
||||
while let Some(child) = todo_children.pop() {
|
||||
if environment.pop_incremented(&child) {
|
||||
if environment.incremented_symbols.pop(&child) {
|
||||
incremented_children.insert(child);
|
||||
} else {
|
||||
todo_children.extend(environment.get_children(&child));
|
||||
|
@ -519,8 +533,10 @@ fn specialize_drops_stmt<'a, 'i>(
|
|||
};
|
||||
|
||||
// Add back the increments for the children to the environment.
|
||||
for child_symbol in incremented_children.iter() {
|
||||
environment.add_incremented(*child_symbol, 1)
|
||||
for (child_symbol, symbol_count) in incremented_children.map.into_iter() {
|
||||
environment
|
||||
.incremented_symbols
|
||||
.insert_count(child_symbol, symbol_count)
|
||||
}
|
||||
|
||||
updated_stmt
|
||||
|
@ -638,7 +654,7 @@ fn specialize_struct<'a, 'i>(
|
|||
environment: &mut DropSpecializationEnvironment<'a>,
|
||||
symbol: &Symbol,
|
||||
struct_layout: &'a [InLayout],
|
||||
incremented_children: &mut MutSet<Child>,
|
||||
incremented_children: &mut CountingMap<Child>,
|
||||
continuation: &'a Stmt<'a>,
|
||||
) -> &'a Stmt<'a> {
|
||||
match environment.struct_children.get(symbol) {
|
||||
|
@ -654,7 +670,7 @@ fn specialize_struct<'a, 'i>(
|
|||
|
||||
for (index, _layout) in struct_layout.iter().enumerate() {
|
||||
for (child, _i) in children_clone.iter().filter(|(_, i)| *i == index as u64) {
|
||||
let removed = incremented_children.remove(child);
|
||||
let removed = incremented_children.pop(child);
|
||||
index_symbols.insert(index, (*child, removed));
|
||||
|
||||
if removed {
|
||||
|
@ -727,7 +743,7 @@ fn specialize_union<'a, 'i>(
|
|||
environment: &mut DropSpecializationEnvironment<'a>,
|
||||
symbol: &Symbol,
|
||||
union_layout: UnionLayout<'a>,
|
||||
incremented_children: &mut MutSet<Child>,
|
||||
incremented_children: &mut CountingMap<Child>,
|
||||
continuation: &'a Stmt<'a>,
|
||||
) -> &'a Stmt<'a> {
|
||||
let current_tag = environment.symbol_tag.get(symbol).copied();
|
||||
|
@ -770,7 +786,7 @@ fn specialize_union<'a, 'i>(
|
|||
{
|
||||
debug_assert_eq!(tag, *t);
|
||||
|
||||
let removed = incremented_children.remove(child);
|
||||
let removed = incremented_children.pop(child);
|
||||
index_symbols.insert(index, (*child, removed));
|
||||
|
||||
if removed {
|
||||
|
@ -932,14 +948,14 @@ fn specialize_boxed<'a, 'i>(
|
|||
layout_interner: &'i mut STLayoutInterner<'a>,
|
||||
ident_ids: &'i mut IdentIds,
|
||||
environment: &mut DropSpecializationEnvironment<'a>,
|
||||
incremented_children: &mut MutSet<Child>,
|
||||
incremented_children: &mut CountingMap<Child>,
|
||||
symbol: &Symbol,
|
||||
continuation: &'a Stmt<'a>,
|
||||
) -> &'a Stmt<'a> {
|
||||
let removed = match incremented_children.iter().next() {
|
||||
Some(s) => {
|
||||
let removed = match incremented_children.map.iter().next() {
|
||||
Some((s, _)) => {
|
||||
let s = *s;
|
||||
incremented_children.remove(&s);
|
||||
incremented_children.pop(&s);
|
||||
Some(s)
|
||||
}
|
||||
None => None,
|
||||
|
@ -989,7 +1005,7 @@ fn specialize_list<'a, 'i>(
|
|||
layout_interner: &'i mut STLayoutInterner<'a>,
|
||||
ident_ids: &'i mut IdentIds,
|
||||
environment: &mut DropSpecializationEnvironment<'a>,
|
||||
incremented_children: &mut MutSet<Child>,
|
||||
incremented_children: &mut CountingMap<Child>,
|
||||
symbol: &Symbol,
|
||||
item_layout: InLayout,
|
||||
continuation: &'a Stmt<'a>,
|
||||
|
@ -1024,7 +1040,7 @@ fn specialize_list<'a, 'i>(
|
|||
for (child, i) in children_clone.iter().filter(|(_child, i)| *i == index) {
|
||||
debug_assert!(length > *i);
|
||||
|
||||
let removed = incremented_children.remove(child);
|
||||
let removed = incremented_children.pop(child);
|
||||
index_symbols.insert(index, (*child, removed));
|
||||
|
||||
if removed {
|
||||
|
@ -1262,7 +1278,7 @@ struct DropSpecializationEnvironment<'a> {
|
|||
list_children: MutMap<Parent, Vec<'a, (Child, Index)>>,
|
||||
|
||||
// Keeps track of all incremented symbols.
|
||||
incremented_symbols: MutMap<Symbol, u64>,
|
||||
incremented_symbols: CountingMap<Symbol>,
|
||||
|
||||
// Map containing the current known tag of a layout.
|
||||
symbol_tag: MutMap<Symbol, Tag>,
|
||||
|
@ -1286,7 +1302,7 @@ impl<'a> DropSpecializationEnvironment<'a> {
|
|||
union_children: MutMap::default(),
|
||||
box_children: MutMap::default(),
|
||||
list_children: MutMap::default(),
|
||||
incremented_symbols: MutMap::default(),
|
||||
incremented_symbols: CountingMap::new(),
|
||||
symbol_tag: MutMap::default(),
|
||||
symbol_index: MutMap::default(),
|
||||
list_length: MutMap::default(),
|
||||
|
@ -1304,7 +1320,7 @@ impl<'a> DropSpecializationEnvironment<'a> {
|
|||
union_children: self.union_children.clone(),
|
||||
box_children: self.box_children.clone(),
|
||||
list_children: self.list_children.clone(),
|
||||
incremented_symbols: MutMap::default(),
|
||||
incremented_symbols: CountingMap::new(),
|
||||
symbol_tag: self.symbol_tag.clone(),
|
||||
symbol_index: self.symbol_index.clone(),
|
||||
list_length: self.list_length.clone(),
|
||||
|
@ -1381,39 +1397,6 @@ impl<'a> DropSpecializationEnvironment<'a> {
|
|||
|
||||
res
|
||||
}
|
||||
|
||||
/**
|
||||
Add a symbol for every increment performed.
|
||||
*/
|
||||
fn add_incremented(&mut self, symbol: Symbol, count: u64) {
|
||||
self.incremented_symbols
|
||||
.entry(symbol)
|
||||
.and_modify(|c| *c += count)
|
||||
.or_insert(count);
|
||||
}
|
||||
|
||||
fn any_incremented(&self, symbol: &Symbol) -> bool {
|
||||
self.incremented_symbols.contains_key(symbol)
|
||||
}
|
||||
|
||||
/**
|
||||
Return the amount of times a symbol still has to be incremented.
|
||||
Accounting for later consumtion and removal of the increment.
|
||||
*/
|
||||
fn get_incremented(&mut self, symbol: &Symbol) -> u64 {
|
||||
self.incremented_symbols.remove(symbol).unwrap_or(0)
|
||||
}
|
||||
|
||||
fn pop_incremented(&mut self, symbol: &Symbol) -> bool {
|
||||
match self.incremented_symbols.get_mut(symbol) {
|
||||
Some(0) => false,
|
||||
Some(c) => {
|
||||
*c -= 1;
|
||||
true
|
||||
}
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1525,3 +1508,51 @@ fn low_level_no_rc(lowlevel: &LowLevel) -> RC {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct CountingMap<K>
|
||||
where
|
||||
K: Eq + std::hash::Hash,
|
||||
{
|
||||
map: MutMap<K, u64>,
|
||||
}
|
||||
|
||||
impl<K> CountingMap<K>
|
||||
where
|
||||
K: Eq + std::hash::Hash,
|
||||
{
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
map: MutMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn insert(&mut self, key: K) {
|
||||
self.map.entry(key).and_modify(|c| *c += 1).or_insert(1);
|
||||
}
|
||||
|
||||
fn insert_count(&mut self, key: K, count: u64) {
|
||||
self.map
|
||||
.entry(key)
|
||||
.and_modify(|c| *c += count)
|
||||
.or_insert(count);
|
||||
}
|
||||
|
||||
fn pop(&mut self, key: &K) -> bool {
|
||||
match self.map.get_mut(key) {
|
||||
Some(1) => {
|
||||
self.map.remove(key);
|
||||
true
|
||||
}
|
||||
Some(c) => {
|
||||
*c -= 1;
|
||||
true
|
||||
}
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn contains(&self, symbol: &K) -> bool {
|
||||
self.map.contains_key(symbol)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue