diff --git a/compiler/mono/src/borrow.rs b/compiler/mono/src/borrow.rs index 7c8aabdb00..9b9e46ef8f 100644 --- a/compiler/mono/src/borrow.rs +++ b/compiler/mono/src/borrow.rs @@ -23,17 +23,19 @@ pub fn infer_borrow<'a>( // intern the layouts let mut declaration_to_index = MutMap::with_capacity_and_hasher(procs.len(), default_hasher()); - let mut i = 0; - for key in procs.keys() { - declaration_to_index.insert(*key, DeclarationId(i)); + let mut param_map = { + let mut i = 0; + for key in procs.keys() { + declaration_to_index.insert(*key, ParamOffset(i)); - i += key.1.arguments.len(); - } + i += key.1.arguments.len(); + } - let mut param_map = ParamMap { - declaration_to_index, - join_points: MutMap::default(), - declarations: bumpalo::vec![in arena; Param::EMPTY; i], + ParamMap { + declaration_to_index, + join_points: MutMap::default(), + declarations: bumpalo::vec![in arena; Param::EMPTY; i], + } }; for (key, proc) in procs { @@ -48,28 +50,110 @@ pub fn infer_borrow<'a>( arena, }; - // This is a fixed-point analysis - // - // all functions initiall own all their parameters - // through a series of checks and heuristics, some arguments are set to borrowed - // when that doesn't lead to conflicts the change is kept, otherwise it may be reverted - // - // when the signatures no longer change, the analysis stops and returns the signatures - loop { - // sort the symbols (roughly) in definition order. - // TODO in the future I think we need to do this properly, and group - // mutually recursive functions (or just make all their arguments owned) + // next we first partition the functions into strongly connected components, then do a + // topological sort on these components, finally run the fix-point borrow analysis on each + // component (in top-sorted order, from primitives (std-lib) to main) - for (key, proc) in procs { - env.collect_proc(&mut param_map, proc, key.1); + let successor_map = &make_successor_mapping(arena, procs); + let successors = move |key: &Symbol| successor_map[key].iter().copied(); + + let mut symbols = Vec::with_capacity_in(procs.len(), arena); + symbols.extend(procs.keys().map(|x| x.0)); + + let sccs = ven_graph::strongly_connected_components(&symbols, successors); + + let mut symbol_to_component = MutMap::default(); + for (i, symbols) in sccs.iter().enumerate() { + for symbol in symbols { + symbol_to_component.insert(*symbol, i); + } + } + + let mut component_to_successors = Vec::with_capacity_in(sccs.len(), arena); + for (i, symbols) in sccs.iter().enumerate() { + // guess: every function has ~1 successor + let mut succs = Vec::with_capacity_in(symbols.len(), arena); + + for symbol in symbols { + for s in successors(symbol) { + let c = symbol_to_component[&s]; + + // don't insert self to prevent cycles + if c != i { + succs.push(c); + } + } } - if !env.modified { - // if there were no modifications, we're done - break; - } else { - // otherwise see if there are changes after another iteration - env.modified = false; + succs.sort_unstable(); + succs.dedup(); + + component_to_successors.push(succs); + } + + let mut components = Vec::with_capacity_in(component_to_successors.len(), arena); + components.extend(0..component_to_successors.len()); + + let mut groups = Vec::new_in(arena); + + let component_to_successors = &component_to_successors; + match ven_graph::topological_sort_into_groups(&components, |c: &usize| { + component_to_successors[*c].iter().copied() + }) { + Ok(component_groups) => { + let mut component_to_group = bumpalo::vec![in arena; usize::MAX; components.len()]; + + // for each component, store which group it is in + for (group_index, component_group) in component_groups.iter().enumerate() { + for component in component_group { + component_to_group[*component] = group_index; + } + } + + // prepare groups + groups.reserve(component_groups.len()); + for _ in 0..component_groups.len() { + groups.push(Vec::new_in(arena)); + } + + for (key, proc) in procs { + let symbol = key.0; + let offset = param_map.declaration_to_index[key]; + + // the component this symbol is a part of + let component = symbol_to_component[&symbol]; + + // now find the group that this component belongs to + let group = component_to_group[component]; + + groups[group].push((proc, offset)); + } + } + Err((_groups, _remainder)) => { + unreachable!("because we find strongly-connected components first"); + } + } + + for group in groups.into_iter().rev() { + // This is a fixed-point analysis + // + // all functions initiall own all their parameters + // through a series of checks and heuristics, some arguments are set to borrowed + // when that doesn't lead to conflicts the change is kept, otherwise it may be reverted + // + // when the signatures no longer change, the analysis stops and returns the signatures + loop { + for (proc, param_offset) in group.iter() { + env.collect_proc(&mut param_map, proc, *param_offset); + } + + if !env.modified { + // if there were no modifications, we're done + break; + } else { + // otherwise see if there are changes after another iteration + env.modified = false; + } } } @@ -77,10 +161,10 @@ pub fn infer_borrow<'a>( } #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] -pub struct DeclarationId(usize); +pub struct ParamOffset(usize); -impl From for usize { - fn from(id: DeclarationId) -> Self { +impl From for usize { + fn from(id: ParamOffset) -> Self { id.0 as usize } } @@ -88,7 +172,7 @@ impl From for usize { #[derive(Debug, Clone)] pub struct ParamMap<'a> { /// Map a (Symbol, ProcLayout) pair to the starting index in the `declarations` array - declaration_to_index: MutMap<(Symbol, ProcLayout<'a>), DeclarationId>, + declaration_to_index: MutMap<(Symbol, ProcLayout<'a>), ParamOffset>, /// the parameters of all functions in a single flat array. /// /// - the map above gives the index of the first parameter for the function @@ -294,10 +378,11 @@ impl<'a> BorrowInfState<'a> { &mut self, param_map: &mut ParamMap<'a>, symbol: Symbol, - layout: ProcLayout<'a>, + start: ParamOffset, + length: usize, ) { - let index: usize = param_map.declaration_to_index[&(symbol, layout)].into(); - let ps = &mut param_map.declarations[index..][..layout.arguments.len()]; + let index: usize = start.into(); + let ps = &mut param_map.declarations[index..][..length]; for p in ps.iter_mut() { if !p.borrow { @@ -762,7 +847,7 @@ impl<'a> BorrowInfState<'a> { &mut self, param_map: &mut ParamMap<'a>, proc: &Proc<'a>, - layout: ProcLayout<'a>, + param_offset: ParamOffset, ) { let old = self.param_set.clone(); @@ -774,7 +859,7 @@ impl<'a> BorrowInfState<'a> { self.owned.entry(proc.name).or_default(); self.collect_stmt(param_map, &proc.body); - self.update_param_map_declaration(param_map, proc.name, layout); + self.update_param_map_declaration(param_map, proc.name, param_offset, proc.args.len()); self.param_set = old; } @@ -868,3 +953,100 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] { ExpectTrue => arena.alloc_slice_copy(&[irrelevant]), } } + +fn make_successor_mapping<'a>( + arena: &'a Bump, + procs: &MutMap<(Symbol, ProcLayout<'_>), Proc<'a>>, +) -> MutMap> { + let mut result = MutMap::with_capacity_and_hasher(procs.len(), default_hasher()); + + for (key, proc) in procs { + let mut call_info = CallInfo { + keys: Vec::new_in(arena), + }; + call_info_stmt(arena, &proc.body, &mut call_info); + + let mut keys = call_info.keys; + keys.sort_unstable(); + keys.dedup(); + + result.insert(key.0, keys); + } + + result +} + +struct CallInfo<'a> { + // keys: MutSet<(Symbol, ProcLayout<'a>)>, + keys: Vec<'a, Symbol>, +} + +fn call_info_call<'a>(call: &crate::ir::Call<'a>, info: &mut CallInfo<'a>) { + use crate::ir::CallType::*; + + match call.call_type { + ByName { + name, + ret_layout, + arg_layouts, + .. + } => { + let proc_layout = crate::ir::ProcLayout { + arguments: arg_layouts, + result: ret_layout, + }; + + //let key = (name, proc_layout); + // info.keys.insert(key); + info.keys.push(name); + } + Foreign { .. } => {} + LowLevel { .. } => {} + HigherOrderLowLevel { .. } => {} + } +} + +fn call_info_stmt<'a>(arena: &'a Bump, stmt: &Stmt<'a>, info: &mut CallInfo<'a>) { + use Stmt::*; + + let mut stack = bumpalo::vec![ in arena; stmt ]; + + while let Some(stmt) = stack.pop() { + match stmt { + Join { + remainder: v, + body: b, + .. + } => { + stack.push(v); + stack.push(b); + } + Let(_, expr, _, cont) => { + if let Expr::Call(call) = expr { + call_info_call(call, info); + } + stack.push(cont); + } + Invoke { + call, pass, fail, .. + } => { + call_info_call(call, info); + stack.push(pass); + stack.push(fail); + } + Switch { + branches, + default_branch, + .. + } => { + stack.extend(branches.iter().map(|b| &b.2)); + stack.push(default_branch.1); + } + Refcounting(_, _) => unreachable!("these have not been introduced yet"), + + Ret(_) | Resume(_) | Jump(_, _) | RuntimeError(_) => { + // these are terminal, do nothing + } + } + } +}