use bitvec-based topological sort

This commit is contained in:
Folkert 2022-04-22 11:35:08 +02:00
parent 7055645085
commit b6ccd9c8fb
No known key found for this signature in database
GPG key ID: 1F17F6FFD112B97C
2 changed files with 496 additions and 6 deletions

View file

@ -11,6 +11,8 @@ use crate::scope::create_alias;
use crate::scope::Scope;
use roc_collections::{default_hasher, ImEntry, ImMap, ImSet, MutMap, MutSet, SendMap};
use roc_module::ident::Lowercase;
use roc_module::symbol::IdentId;
use roc_module::symbol::ModuleId;
use roc_module::symbol::Symbol;
use roc_parse::ast;
use roc_parse::ast::AbilityMember;
@ -659,6 +661,407 @@ pub fn canonicalize_defs<'a>(
)
}
#[derive(Clone, Copy)]
struct DefId(u32);
#[derive(Debug)]
struct DefIds {
home: ModuleId,
symbol_to_id: Vec<(IdentId, u32)>,
// an length x length matrix indicating who references who
references: bitvec::vec::BitVec<u8>,
length: u32,
}
impl DefIds {
fn with_capacity(home: ModuleId, capacity: usize) -> Self {
use bitvec::vec::BitVec;
// makes each new row start at a multiple of 8
// let hack = capacity + (8 - capacity % 8);
let references = BitVec::repeat(false, capacity * capacity);
Self {
home,
symbol_to_id: Vec::with_capacity(capacity),
references,
length: capacity as u32,
}
}
fn from_defs_by_symbol(
env: &Env,
can_defs_by_symbol: &MutMap<Symbol, Def>,
refs_by_symbol: &MutMap<Symbol, (Region, References)>,
) -> Self {
let mut this = Self::with_capacity(env.home, can_defs_by_symbol.len());
for (i, symbol) in can_defs_by_symbol.keys().enumerate() {
debug_assert_eq!(env.home, symbol.module_id());
this.symbol_to_id.push((symbol.ident_id(), i as u32));
}
for (symbol, (_, references)) in refs_by_symbol.iter() {
let def_id = DefId(this.get_id(*symbol).unwrap());
for referenced in references.value_lookups() {
this.register_reference(def_id, *referenced);
}
for referenced in references.calls() {
this.register_reference(def_id, *referenced);
}
if let Some(references) = env.closures.get(symbol) {
for referenced in references.value_lookups() {
this.register_reference(def_id, *referenced);
}
for referenced in references.calls() {
this.register_reference(def_id, *referenced);
}
}
}
this
}
fn get_id(&self, symbol: Symbol) -> Option<u32> {
self.symbol_to_id
.iter()
.find(|(id, _)| *id == symbol.ident_id())
.map(|t| t.1)
}
fn get_symbol(&self, id: u32) -> Option<Symbol> {
self.symbol_to_id
.iter()
.find(|(_, def_id)| id == *def_id)
.map(|t| Symbol::new(self.home, t.0))
}
fn register_reference(&mut self, id: DefId, referenced: Symbol) -> bool {
if referenced.module_id() != self.home {
return false;
}
match self.get_id(referenced) {
None => {
// this symbol is not defined within the let-block that this DefIds represents
false
}
Some(referenced_id) => {
let row = id.0;
let column = referenced_id;
let index = row * self.length + column;
self.references.set(index as usize, true);
true
}
}
}
fn calls_itself_directly(&self, id: u32) -> bool {
let row = &self.references[(id * self.length) as usize..][..self.length as usize];
row[id as usize]
}
#[inline(always)]
fn successors(&self, id: u32) -> impl Iterator<Item = u32> + '_ {
let row = &self.references[(id * self.length) as usize..][..self.length as usize];
row.iter_ones().map(|x| x as u32)
}
#[inline(always)]
fn successors_without_self(&self, id: u32) -> impl Iterator<Item = u32> + '_ {
self.successors(id).filter(move |x| *x != id)
}
#[allow(clippy::type_complexity)]
fn topological_sort_into_groups(&self) -> Result<Vec<Vec<u32>>, (Vec<Vec<u32>>, Vec<u32>)> {
let length = self.length as usize;
let bitvec = &self.references;
if length == 0 {
return Ok(Vec::new());
}
let mut preds_map: Vec<i64> = vec![0; length];
// this is basically summing the columns, I don't see a better way to do it
for row in bitvec.chunks(length) {
for succ in row.iter_ones() {
preds_map[succ] += 1;
}
}
let mut groups = Vec::<Vec<u32>>::new();
// the initial group contains all symbols with no predecessors
let mut prev_group: Vec<u32> = preds_map
.iter()
.enumerate()
.filter_map(|(node, &num_preds)| {
if num_preds == 0 {
Some(node as u32)
} else {
None
}
})
.collect();
if prev_group.is_empty() {
let remaining: Vec<u32> = (0u32..length as u32).collect();
return Err((Vec::new(), remaining));
}
// NOTE: the original now removes elements from the preds_map if they have count 0
// for node in &prev_group {
// preds_map.remove(node);
// }
while preds_map.iter().any(|x| *x > 0) {
let mut next_group = Vec::<u32>::new();
for node in &prev_group {
let row = &bitvec[length * (*node as usize)..][..length];
for succ in row.iter_ones() {
{
let num_preds = preds_map.get_mut(succ).unwrap();
*num_preds = num_preds.saturating_sub(1);
if *num_preds > 0 {
continue;
}
}
let count = preds_map[succ];
preds_map[succ] = -1;
if count > -1 {
next_group.push(succ as u32);
}
}
}
groups.push(std::mem::replace(&mut prev_group, next_group));
if prev_group.is_empty() {
let remaining: Vec<u32> = (0u32..length as u32)
.filter(|i| preds_map[*i as usize] > 0)
.collect();
return Err((groups, remaining));
}
}
groups.push(prev_group);
Ok(groups)
}
#[allow(dead_code)]
fn debug_relations(&self) {
for id in 0u32..self.length as u32 {
let row = &self.references[(id * self.length) as usize..][..self.length as usize];
let matches = row
.iter()
.enumerate()
.filter(move |t| *t.1)
.map(|t| t.0 as u32);
for m in matches {
let a = self.get_symbol(id).unwrap();
let b = self.get_symbol(m).unwrap();
println!("{:?} <- {:?}", a, b);
}
}
}
}
#[inline(always)]
pub fn sort_can_defs_improved(
env: &mut Env<'_>,
defs: CanDefs,
mut output: Output,
) -> (Result<Vec<Declaration>, RuntimeError>, Output) {
let def_ids = DefIds::from_defs_by_symbol(env, &defs.can_defs_by_symbol, &defs.refs_by_symbol);
let CanDefs {
refs_by_symbol,
mut can_defs_by_symbol,
aliases,
} = defs;
for (symbol, alias) in aliases.into_iter() {
output.aliases.insert(symbol, alias);
}
// TODO also do the same `addDirects` check elm/compiler does, so we can
// report an error if a recursive definition can't possibly terminate!
match def_ids.topological_sort_into_groups() {
Ok(groups) => {
let mut declarations = Vec::new();
// groups are in reversed order
for group in groups.into_iter().rev() {
group_to_declaration_improved(
&def_ids,
&group,
&env.closures,
&mut can_defs_by_symbol,
&mut declarations,
);
}
(Ok(declarations), output)
}
Err((mut groups, nodes_in_cycle)) => {
let mut declarations = Vec::new();
let mut problems = Vec::new();
// nodes_in_cycle are symbols that form a syntactic cycle. That isn't always a problem,
// and in general it's impossible to decide whether it is. So we use a crude heuristic:
//
// Definitions where the cycle occurs behind a lambda are OK
//
// boom = \_ -> boom {}
//
// But otherwise we report an error, e.g.
//
// foo = if b then foo else bar
let all_successors_without_self = |id: &u32| {
let id = *id;
def_ids.successors_without_self(id)
};
for cycle in strongly_connected_components(&nodes_in_cycle, all_successors_without_self)
{
// check whether the cycle is faulty, which is when it has
// a direct successor in the current cycle. This catches things like:
//
// x = x
//
// or
//
// p = q
// q = p
let is_invalid_cycle = match cycle.get(0) {
Some(def_id) => def_ids.successors(*def_id).any(|key| cycle.contains(&key)),
None => false,
};
if is_invalid_cycle {
// We want to show the entire cycle in the error message, so expand it out.
let mut entries = Vec::new();
for def_id in &cycle {
let symbol = def_ids.get_symbol(*def_id).unwrap();
match refs_by_symbol.get(&symbol) {
None => unreachable!(
r#"Symbol `{:?}` not found in refs_by_symbol! refs_by_symbol was: {:?}"#,
symbol, refs_by_symbol
),
Some((region, _)) => {
let expr_region =
can_defs_by_symbol.get(&symbol).unwrap().loc_expr.region;
let entry = CycleEntry {
symbol,
symbol_region: *region,
expr_region,
};
entries.push(entry);
}
}
}
// Sort them by line number to make the report more helpful.
entries.sort_by_key(|entry| entry.symbol_region);
problems.push(Problem::RuntimeError(RuntimeError::CircularDef(
entries.clone(),
)));
declarations.push(Declaration::InvalidCycle(entries));
}
// if it's an invalid cycle, other groups may depend on the
// symbols defined here, so also push this cycle onto the groups
//
// if it's not an invalid cycle, this is slightly inefficient,
// because we know this becomes exactly one DeclareRec already
groups.push(cycle);
}
// now we have a collection of groups whose dependencies are not cyclic.
// They are however not yet topologically sorted. Here we have to get a bit
// creative to get all the definitions in the correct sorted order.
let mut group_ids = Vec::with_capacity(groups.len());
let mut symbol_to_group_index = MutMap::default();
for (i, group) in groups.iter().enumerate() {
for symbol in group {
symbol_to_group_index.insert(*symbol, i);
}
group_ids.push(i);
}
let successors_of_group = |group_id: &usize| {
let mut result = MutSet::default();
// for each symbol in this group
for symbol in &groups[*group_id] {
// find its successors
for succ in all_successors_without_self(symbol) {
// and add its group to the result
match symbol_to_group_index.get(&succ) {
Some(index) => {
result.insert(*index);
}
None => unreachable!("no index for symbol {:?}", succ),
}
}
}
// don't introduce any cycles to self
result.remove(group_id);
result
};
match ven_graph::topological_sort_into_groups(&group_ids, successors_of_group) {
Ok(sorted_group_ids) => {
for sorted_group in sorted_group_ids.iter().rev() {
for group_id in sorted_group.iter().rev() {
let group = &groups[*group_id];
group_to_declaration_improved(
&def_ids,
group,
&env.closures,
&mut can_defs_by_symbol,
&mut declarations,
);
}
}
}
Err(_) => unreachable!("there should be no cycles now!"),
}
for problem in problems {
env.problem(problem);
}
(Ok(declarations), output)
}
}
}
#[inline(always)]
pub fn sort_can_defs(
env: &mut Env<'_>,
@ -807,12 +1210,14 @@ pub fn sort_can_defs(
}
};
// TODO also do the same `addDirects` check elm/compiler does, so we can
// report an error if a recursive definition can't possibly terminate!
match ven_graph::topological_sort_into_groups(
let grouped2 = ven_graph::topological_sort_into_groups(
defined_symbols.as_slice(),
all_successors_without_self,
) {
);
// TODO also do the same `addDirects` check elm/compiler does, so we can
// report an error if a recursive definition can't possibly terminate!
match grouped2 {
Ok(groups) => {
let mut declarations = Vec::new();
@ -1060,6 +1465,89 @@ fn group_to_declaration(
}
}
fn group_to_declaration_improved(
def_ids: &DefIds,
group: &[u32],
closures: &MutMap<Symbol, References>,
can_defs_by_symbol: &mut MutMap<Symbol, Def>,
declarations: &mut Vec<Declaration>,
) {
use Declaration::*;
// We want only successors in the current group, otherwise definitions get duplicated
let filtered_successors = |id: &u32| def_ids.successors(*id).filter(|key| group.contains(key));
// Patterns like
//
// { x, y } = someDef
//
// Can bind multiple symbols. When not incorrectly recursive (which is guaranteed in this function),
// normally `someDef` would be inserted twice. We use the region of the pattern as a unique key
// for a definition, so every definition is only inserted (thus typechecked and emitted) once
let mut seen_pattern_regions: Vec<Region> = Vec::with_capacity(2);
for cycle in strongly_connected_components(group, filtered_successors) {
if cycle.len() == 1 {
let def_id = cycle[0];
let symbol = def_ids.get_symbol(def_id).unwrap();
match can_defs_by_symbol.remove(&symbol) {
Some(mut new_def) => {
// Determine recursivity of closures that are not tail-recursive
if let Closure(ClosureData {
recursive: recursive @ Recursive::NotRecursive,
..
}) = &mut new_def.loc_expr.value
{
*recursive = closure_recursivity(symbol, closures);
}
let is_recursive = def_ids.calls_itself_directly(def_id);
if !seen_pattern_regions.contains(&new_def.loc_pattern.region) {
seen_pattern_regions.push(new_def.loc_pattern.region);
if is_recursive {
declarations.push(DeclareRec(vec![new_def]));
} else {
declarations.push(Declare(new_def));
}
}
}
None => roc_error_macros::internal_error!("def not available {:?}", symbol),
}
} else {
let mut can_defs = Vec::new();
// Topological sort gives us the reverse of the sorting we want!
for def_id in cycle.into_iter().rev() {
let symbol = def_ids.get_symbol(def_id).unwrap();
match can_defs_by_symbol.remove(&symbol) {
Some(mut new_def) => {
// Determine recursivity of closures that are not tail-recursive
if let Closure(ClosureData {
recursive: recursive @ Recursive::NotRecursive,
..
}) = &mut new_def.loc_expr.value
{
*recursive = closure_recursivity(symbol, closures);
}
if !seen_pattern_regions.contains(&new_def.loc_pattern.region) {
seen_pattern_regions.push(new_def.loc_pattern.region);
can_defs.push(new_def);
}
}
None => roc_error_macros::internal_error!("def not available {:?}", symbol),
}
}
declarations.push(DeclareRec(can_defs));
}
}
}
fn pattern_to_vars_by_symbol(
vars_by_symbol: &mut SendMap<Symbol, Variable>,
pattern: &Pattern,