Simplified logic

This commit is contained in:
J.Teeuwissen 2023-06-07 17:56:00 +02:00 committed by Folkert
parent 46bff75517
commit fbf3faeaf1
No known key found for this signature in database
GPG key ID: 1F17F6FFD112B97C
2 changed files with 191 additions and 239 deletions

View file

@ -650,6 +650,9 @@ fn specialize_drops_stmt<'a, 'i>(
body,
remainder,
} => {
// We cannot perform this optimization if the joinpoint is recursive.
// E.g. if the body of a recursive joinpoint contains an increment, we do not want to move that increment up to the remainder.
let mut remainder_enviroment = environment.clone();
let new_remainder = specialize_drops_stmt(
@ -660,27 +663,32 @@ fn specialize_drops_stmt<'a, 'i>(
remainder,
);
let jump_symbols_option = remainder_enviroment
.jump_incremented_symbols
.get(id)
.and_then(|jump_symbols| {
(!jump_symbols.is_empty()).then_some(jump_symbols.clone())
});
let mut body_environment = environment.clone();
for param in parameters.iter() {
body_environment.add_symbol_layout(param.symbol, param.layout);
}
body_environment.incremented_symbols.clear();
let (new_body, newer_remainder) = match jump_symbols_option {
// The remainder has no jumps with symbols to this joinpoint.
None => {
// Determine the body with no usage.
body_environment
.join_incremented_symbols
.insert(*id, JoinUsage::NoUsage);
let new_body = specialize_drops_stmt(
let new_body = specialize_drops_stmt(
arena,
layout_interner,
ident_ids,
&mut body_environment,
body,
);
let remainder_jump_info = remainder_enviroment.jump_incremented_symbols.get(id);
let body_jump_info = body_environment.jump_incremented_symbols.get(id);
let (newer_body, newer_remainder) = match (remainder_jump_info, body_jump_info) {
// We have info from the remainder, and the body is not recursive.
// Meaning we can pass the incremented_symbols from the remainder to the body.
(Some(jump_info), None) if !jump_info.is_empty() => {
// Update body with incremented symbols from remainder
body_environment.incremented_symbols = jump_info.clone();
let newer_body = specialize_drops_stmt(
arena,
layout_interner,
ident_ids,
@ -688,68 +696,14 @@ fn specialize_drops_stmt<'a, 'i>(
body,
);
// Keep the remainder as is.
*environment = remainder_enviroment;
(new_body, new_remainder)
}
// The remainder has jumps with incremented symbols. We need to perform iteration.
Some(remainder_jump_incremented_symbols) => {
// Symbols the join consumes, decreases by iterating.
let mut join_consumes = remainder_jump_incremented_symbols;
// Symbols the join returns, increases by iterating.
let mut join_returns = CountingMap::new();
// Perform iteration to get the incremented_symbols for the body.
let (new_body, alloced_joinpoint_info) = loop {
// Update the incremented_symbols to the remainder's incremented_symbols.
let mut current_body_environment = body_environment.clone();
current_body_environment.incremented_symbols = join_consumes.clone();
let joinpoint_info = JoinUsage::HasUsage {
join_consumes: join_consumes.clone(),
join_returns: join_returns.clone(),
};
current_body_environment
.join_incremented_symbols
.insert(*id, joinpoint_info);
// TODO make sure fixed point only shrinks.
let new_body = specialize_drops_stmt(
arena,
layout_interner,
ident_ids,
&mut current_body_environment,
body,
);
let new_join_consumes = current_body_environment
.jump_incremented_symbols
.get(id)
.map_or(join_consumes.clone(), |new_join_consumes| {
new_join_consumes.clone()
});
let new_join_returns = current_body_environment.incremented_symbols;
if join_consumes == new_join_consumes && join_returns == new_join_returns {
let new_joinpoint_info = JoinUsage::HasUsage {
join_consumes: new_join_consumes,
join_returns: new_join_returns,
};
break (new_body, new_joinpoint_info);
} else {
join_consumes = new_join_consumes;
join_returns = new_join_returns;
}
};
// Update remainder
environment
.join_incremented_symbols
.insert(*id, alloced_joinpoint_info);
environment.join_incremented_symbols.insert(
*id,
JoinUsage {
join_consumes: jump_info.clone(),
join_returns: body_environment.incremented_symbols,
},
);
let newer_remainder = specialize_drops_stmt(
arena,
layout_interner,
@ -758,45 +712,46 @@ fn specialize_drops_stmt<'a, 'i>(
remainder,
);
(new_body, newer_remainder)
(newer_body, newer_remainder)
}
_ => {
// Keep the body and remainder as is.
// Update the environment with remainder environment.
*environment = remainder_enviroment;
(new_body, new_remainder)
}
};
arena.alloc(Stmt::Join {
id: *id,
parameters,
body: new_body,
body: newer_body,
remainder: newer_remainder,
})
}
Stmt::Jump(joinpoint_id, arguments) => {
match environment.join_incremented_symbols.get(joinpoint_id) {
Some(join_usage) => {
match join_usage {
JoinUsage::NoUsage => {
// Do nothing.
}
JoinUsage::HasUsage {
join_consumes,
join_returns,
} => {
// Consume all symbols that were consumed in the join.
for (symbol, count) in join_consumes.map.iter() {
for _ in 0..*count {
let popped = environment.incremented_symbols.pop(symbol);
debug_assert!(
popped,
"Every incremented symbol should be available from jumps"
);
}
}
for (symbol, count) in join_returns.map.iter() {
environment
.incremented_symbols
.insert_count(*symbol, *count);
}
Some(JoinUsage {
join_consumes,
join_returns,
}) => {
// Consume all symbols that were consumed in the join.
for (symbol, count) in join_consumes.map.iter() {
for _ in 0..*count {
let popped = environment.incremented_symbols.pop(symbol);
debug_assert!(
popped,
"Every incremented symbol should be available from jumps"
);
}
}
for (symbol, count) in join_returns.map.iter() {
environment
.incremented_symbols
.insert_count(*symbol, *count);
}
}
None => {
// No join usage, let the join know the minimum amount of symbols that were incremented from each jump.
@ -1465,12 +1420,9 @@ struct DropSpecializationEnvironment<'a> {
}
#[derive(Clone)]
enum JoinUsage {
NoUsage,
HasUsage {
join_consumes: CountingMap<Symbol>,
join_returns: CountingMap<Symbol>,
},
struct JoinUsage {
join_consumes: CountingMap<Symbol>,
join_returns: CountingMap<Symbol>,
}
impl<'a> DropSpecializationEnvironment<'a> {