use crate::ir::{Expr, JoinPointId, Param, Proc, Stmt}; use crate::layout::Layout; use bumpalo::collections::Vec; use bumpalo::Bump; use roc_collections::all::{MutMap, MutSet}; use roc_module::low_level::LowLevel; use roc_module::symbol::Symbol; fn should_borrow_layout(layout: &Layout) -> bool { match layout { Layout::Closure(_, _, _) => false, _ => layout.is_refcounted(), } } pub fn infer_borrow<'a>( arena: &'a Bump, procs: &MutMap<(Symbol, Layout<'a>), Proc<'a>>, ) -> ParamMap<'a> { let mut param_map = ParamMap { items: MutMap::default(), }; for (key, proc) in procs { param_map.visit_proc(arena, proc, key.clone()); } let mut env = BorrowInfState { current_proc: Symbol::ATTR_ATTR, param_set: MutSet::default(), owned: MutMap::default(), modified: false, param_map, arena, }; // This is a fixed-point analysis // // all functions initiall own all their paramters // through a series of checks and heuristics, some arguments are set to borrowed // when that doesn't lead to conflicts the change is kept, otherwise it may be reverted // // when the signatures no longer change, the analysis stops and returns the signatures loop { // sort the symbols (roughly) in definition order. // TODO in the future I think we need to do this properly, and group // mutually recursive functions (or just make all their arguments owned) for (key, proc) in procs { env.collect_proc(proc, key.1.clone()); } if !env.modified { // if there were no modifications, we're done break; } else { // otherwise see if there are changes after another iteration env.modified = false; } } env.param_map } #[derive(Debug, PartialEq, Eq, Hash, Clone)] pub enum Key<'a> { Declaration(Symbol, Layout<'a>), JoinPoint(JoinPointId), } #[derive(Debug, Clone, Default)] pub struct ParamMap<'a> { items: MutMap, &'a [Param<'a>]>, } impl<'a> IntoIterator for ParamMap<'a> { type Item = (Key<'a>, &'a [Param<'a>]); type IntoIter = , &'a [Param<'a>]> as IntoIterator>::IntoIter; fn into_iter(self) -> Self::IntoIter { self.items.into_iter() } } impl<'a> IntoIterator for &'a ParamMap<'a> { type Item = (&'a Key<'a>, &'a &'a [Param<'a>]); type IntoIter = <&'a std::collections::HashMap, &'a [Param<'a>]> as IntoIterator>::IntoIter; fn into_iter(self) -> Self::IntoIter { self.items.iter() } } impl<'a> ParamMap<'a> { pub fn get_symbol(&self, symbol: Symbol, layout: Layout<'a>) -> Option<&'a [Param<'a>]> { let key = Key::Declaration(symbol, layout); self.items.get(&key).copied() } pub fn get_join_point(&self, id: JoinPointId) -> &'a [Param<'a>] { let key = Key::JoinPoint(id); match self.items.get(&key) { Some(slice) => slice, None => unreachable!("join point not in param map: {:?}", id), } } } impl<'a> ParamMap<'a> { fn init_borrow_params(arena: &'a Bump, ps: &'a [Param<'a>]) -> &'a [Param<'a>] { Vec::from_iter_in( ps.iter().map(|p| Param { borrow: p.layout.is_refcounted(), layout: p.layout.clone(), symbol: p.symbol, }), arena, ) .into_bump_slice() } fn init_borrow_args(arena: &'a Bump, ps: &'a [(Layout<'a>, Symbol)]) -> &'a [Param<'a>] { Vec::from_iter_in( ps.iter().map(|(layout, symbol)| Param { borrow: should_borrow_layout(layout), layout: layout.clone(), symbol: *symbol, }), arena, ) .into_bump_slice() } fn init_borrow_args_always_owned( arena: &'a Bump, ps: &'a [(Layout<'a>, Symbol)], ) -> &'a [Param<'a>] { Vec::from_iter_in( ps.iter().map(|(layout, symbol)| Param { borrow: false, layout: layout.clone(), symbol: *symbol, }), arena, ) .into_bump_slice() } fn visit_proc(&mut self, arena: &'a Bump, proc: &Proc<'a>, key: (Symbol, Layout<'a>)) { if proc.must_own_arguments { self.visit_proc_always_owned(arena, proc, key); return; } let already_in_there = self.items.insert( Key::Declaration(proc.name, key.1), Self::init_borrow_args(arena, proc.args), ); debug_assert!(already_in_there.is_none()); self.visit_stmt(arena, proc.name, &proc.body); } fn visit_proc_always_owned( &mut self, arena: &'a Bump, proc: &Proc<'a>, key: (Symbol, Layout<'a>), ) { let already_in_there = self.items.insert( Key::Declaration(proc.name, key.1), Self::init_borrow_args_always_owned(arena, proc.args), ); debug_assert!(already_in_there.is_none()); self.visit_stmt(arena, proc.name, &proc.body); } fn visit_stmt(&mut self, arena: &'a Bump, _fnid: Symbol, stmt: &Stmt<'a>) { use Stmt::*; let mut stack = bumpalo::vec![ in arena; stmt ]; while let Some(stmt) = stack.pop() { match stmt { Join { id: j, parameters: xs, remainder: v, continuation: b, } => { let already_in_there = self .items .insert(Key::JoinPoint(*j), Self::init_borrow_params(arena, xs)); debug_assert!(already_in_there.is_none()); stack.push(v); stack.push(b); } Let(_, _, _, cont) => { stack.push(cont); } Invoke { pass, fail, .. } => { stack.push(pass); stack.push(fail); } Switch { branches, default_branch, .. } => { stack.extend(branches.iter().map(|b| &b.2)); stack.push(default_branch.1); } Refcounting(_, _) => unreachable!("these have not been introduced yet"), Ret(_) | Rethrow | Jump(_, _) | RuntimeError(_) => { // these are terminal, do nothing } } } } } // Apply the inferred borrow annotations stored in ParamMap to a block of mutually recursive procs struct BorrowInfState<'a> { current_proc: Symbol, param_set: MutSet, owned: MutMap>, modified: bool, param_map: ParamMap<'a>, arena: &'a Bump, } impl<'a> BorrowInfState<'a> { pub fn own_var(&mut self, x: Symbol) { let current = self.owned.get_mut(&self.current_proc).unwrap(); if current.contains(&x) { // do nothing } else { current.insert(x); self.modified = true; } } fn is_owned(&self, x: Symbol) -> bool { match self.owned.get(&self.current_proc) { None => unreachable!( "the current procedure symbol {:?} is not in the owned map", self.current_proc ), Some(set) => set.contains(&x), } } fn update_param_map(&mut self, k: Key<'a>) { let arena = self.arena; if let Some(ps) = self.param_map.items.get(&k) { let ps = Vec::from_iter_in( ps.iter().map(|p| { if !p.borrow { p.clone() } else if self.is_owned(p.symbol) { self.modified = true; let mut p = p.clone(); p.borrow = false; p } else { p.clone() } }), arena, ); self.param_map.items.insert(k, ps.into_bump_slice()); } } /// This looks at an application `f x1 x2 x3` /// If the parameter (based on the definition of `f`) is owned, /// then the argument must also be owned fn own_args_using_params(&mut self, xs: &[Symbol], ps: &[Param<'a>]) { debug_assert_eq!(xs.len(), ps.len()); for (x, p) in xs.iter().zip(ps.iter()) { if !p.borrow { self.own_var(*x); } } } /// This looks at an application `f x1 x2 x3` /// If the parameter (based on the definition of `f`) is owned, /// then the argument must also be owned fn own_args_using_bools(&mut self, xs: &[Symbol], ps: &[bool]) { debug_assert_eq!(xs.len(), ps.len()); for (x, borrow) in xs.iter().zip(ps.iter()) { if !borrow { self.own_var(*x); } } } fn own_arg(&mut self, x: Symbol) { self.own_var(x); } fn own_args(&mut self, xs: &[Symbol]) { for x in xs.iter() { self.own_arg(*x); } } /// For each xs[i], if xs[i] is owned, then mark ps[i] as owned. /// We use this action to preserve tail calls. That is, if we have /// a tail call `f xs`, if the i-th parameter is borrowed, but `xs[i]` is owned /// we would have to insert a `dec xs[i]` after `f xs` and consequently /// "break" the tail call. fn own_params_using_args(&mut self, xs: &[Symbol], ps: &[Param<'a>]) { debug_assert_eq!(xs.len(), ps.len()); for (x, p) in xs.iter().zip(ps.iter()) { if self.is_owned(*x) { self.own_var(p.symbol); } } } /// Mark `xs[i]` as owned if it is one of the parameters `ps`. /// We use this action to mark function parameters that are being "packed" inside constructors. /// This is a heuristic, and is not related with the effectiveness of the reset/reuse optimization. /// It is useful for code such as /// /// > def f (x y : obj) := /// > let z := ctor_1 x y; /// > ret z fn own_args_if_param(&mut self, xs: &[Symbol]) { for x in xs.iter() { // TODO may also be asking for the index here? see Lean if self.param_set.contains(x) { self.own_var(*x); } } } /// This looks at the assignement /// /// let z = e in ... /// /// and determines whether z and which of the symbols used in e /// must be taken as owned paramters fn collect_call(&mut self, z: Symbol, e: &crate::ir::Call<'a>) { use crate::ir::CallType::*; let crate::ir::Call { call_type, arguments, } = e; match call_type { ByName { name, full_layout, .. } => { // get the borrow signature of the applied function match self.param_map.get_symbol(*name, full_layout.clone()) { Some(ps) => { // the return value will be owned self.own_var(z); // if the function exects an owned argument (ps), the argument must be owned (args) debug_assert_eq!( arguments.len(), ps.len(), "{:?} has {} parameters, but was applied to {} arguments", name, ps.len(), arguments.len() ); self.own_args_using_params(arguments, ps); } None => { // this is really an indirect call, but the function was bound to a symbol // the return value will be owned self.own_var(z); // if the function exects an owned argument (ps), the argument must be owned (args) self.own_args(arguments); } } } ByPointer { .. } => { // the return value will be owned self.own_var(z); // if the function exects an owned argument (ps), the argument must be owned (args) self.own_args(arguments); } LowLevel { op } => { // very unsure what demand RunLowLevel should place upon its arguments self.own_var(z); let ps = lowlevel_borrow_signature(self.arena, *op); self.own_args_using_bools(arguments, ps); } Foreign { .. } => { // very unsure what demand ForeignCall should place upon its arguments self.own_var(z); let ps = foreign_borrow_signature(self.arena, arguments.len()); self.own_args_using_bools(arguments, ps); } } } fn collect_expr(&mut self, z: Symbol, e: &Expr<'a>) { use Expr::*; match e { Tag { arguments: xs, .. } | Struct(xs) | Array { elems: xs, .. } => { self.own_var(z); // if the used symbol is an argument to the current function, // the function must take it as an owned parameter self.own_args_if_param(xs); } Reset(x) => { self.own_var(z); self.own_var(*x); } Reuse { symbol: x, arguments: ys, .. } => { self.own_var(z); self.own_var(*x); self.own_args_if_param(ys); } EmptyArray => { self.own_var(z); } AccessAtIndex { structure: x, .. } => { // if the structure (record/tag/array) is owned, the extracted value is if self.is_owned(*x) { self.own_var(z); } // if the extracted value is owned, the structure must be too if self.is_owned(z) { self.own_var(*x); } } Call(call) => self.collect_call(z, call), Literal(_) | FunctionPointer(_, _) | RuntimeErrorFunction(_) => {} } } #[allow(clippy::many_single_char_names)] fn preserve_tail_call(&mut self, x: Symbol, v: &Expr<'a>, b: &Stmt<'a>) { match (v, b) { ( Expr::Call(crate::ir::Call { call_type: crate::ir::CallType::ByName { name: g, full_layout, .. }, arguments: ys, .. }), Stmt::Ret(z), ) | ( Expr::Call(crate::ir::Call { call_type: crate::ir::CallType::ByPointer { name: g, full_layout, .. }, arguments: ys, .. }), Stmt::Ret(z), ) => { if self.current_proc == *g && x == *z { // anonymous functions (for which the ps may not be known) // can never be tail-recursive, so this is fine if let Some(ps) = self.param_map.get_symbol(*g, full_layout.clone()) { self.own_params_using_args(ys, ps) } } } _ => {} } } fn update_param_set(&mut self, ps: &[Param<'a>]) { for p in ps.iter() { self.param_set.insert(p.symbol); } } fn update_param_set_symbols(&mut self, ps: &[Symbol]) { for p in ps.iter() { self.param_set.insert(*p); } } fn collect_stmt(&mut self, stmt: &Stmt<'a>) { use Stmt::*; match stmt { Join { id: j, parameters: ys, remainder: v, continuation: b, } => { let old = self.param_set.clone(); self.update_param_set(ys); self.collect_stmt(v); self.param_set = old; self.update_param_map(Key::JoinPoint(*j)); self.collect_stmt(b); } Let(x, Expr::FunctionPointer(fsymbol, layout), _, b) => { // ensure that the function pointed to is in the param map if let Some(params) = self.param_map.get_symbol(*fsymbol, layout.clone()) { self.param_map .items .insert(Key::Declaration(*x, layout.clone()), params); } self.collect_stmt(b); self.preserve_tail_call(*x, &Expr::FunctionPointer(*fsymbol, layout.clone()), b); } Let(x, v, _, b) => { self.collect_stmt(b); self.collect_expr(*x, v); self.preserve_tail_call(*x, v, b); } Invoke { symbol, call, layout: _, pass, fail, } => { self.collect_stmt(pass); self.collect_stmt(fail); self.collect_call(*symbol, call); // TODO how to preserve the tail call of an invoke? // self.preserve_tail_call(*x, v, b); } Jump(j, ys) => { let ps = self.param_map.get_join_point(*j); // for making sure the join point can reuse self.own_args_using_params(ys, ps); // for making sure the tail call is preserved self.own_params_using_args(ys, ps); } Switch { branches, default_branch, .. } => { for (_, _, b) in branches.iter() { self.collect_stmt(b); } self.collect_stmt(default_branch.1); } Refcounting(_, _) => unreachable!("these have not been introduced yet"), Ret(_) | RuntimeError(_) | Rethrow => { // these are terminal, do nothing } } } fn collect_proc(&mut self, proc: &Proc<'a>, layout: Layout<'a>) { let old = self.param_set.clone(); let ys = Vec::from_iter_in(proc.args.iter().map(|t| t.1), self.arena).into_bump_slice(); self.update_param_set_symbols(ys); self.current_proc = proc.name; // ensure that current_proc is in the owned map self.owned.entry(proc.name).or_default(); self.collect_stmt(&proc.body); self.update_param_map(Key::Declaration(proc.name, layout)); self.param_set = old; } } pub fn foreign_borrow_signature(arena: &Bump, arity: usize) -> &[bool] { // NOTE this means that Roc is responsible for cleaning up resources; // the host cannot (currently) take ownership let all = bumpalo::vec![in arena; true; arity]; all.into_bump_slice() } pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] { use LowLevel::*; // TODO is true or false more efficient for non-refcounted layouts? let irrelevant = false; let owned = false; let borrowed = true; // Here we define the borrow signature of low-level operations // // - arguments with non-refcounted layouts (ints, floats) are `irrelevant` // - arguments that we may want to update destructively must be Owned // - other refcounted arguments are Borrowed match op { ListLen | StrIsEmpty | StrCountGraphemes => arena.alloc_slice_copy(&[borrowed]), ListSet => arena.alloc_slice_copy(&[owned, irrelevant, irrelevant]), ListSetInPlace => arena.alloc_slice_copy(&[owned, irrelevant, irrelevant]), ListGetUnsafe => arena.alloc_slice_copy(&[borrowed, irrelevant]), ListConcat | StrConcat => arena.alloc_slice_copy(&[borrowed, borrowed]), StrSplit => arena.alloc_slice_copy(&[borrowed, borrowed]), ListSingle => arena.alloc_slice_copy(&[irrelevant]), ListRepeat => arena.alloc_slice_copy(&[irrelevant, borrowed]), ListReverse => arena.alloc_slice_copy(&[owned]), ListPrepend => arena.alloc_slice_copy(&[owned, owned]), StrJoinWith => arena.alloc_slice_copy(&[borrowed, borrowed]), ListJoin => arena.alloc_slice_copy(&[irrelevant]), ListMap | ListMapWithIndex => arena.alloc_slice_copy(&[owned, irrelevant]), ListMap2 => arena.alloc_slice_copy(&[owned, owned, irrelevant]), ListMap3 => arena.alloc_slice_copy(&[owned, owned, owned, irrelevant]), ListKeepIf | ListKeepOks | ListKeepErrs => arena.alloc_slice_copy(&[owned, borrowed]), ListContains => arena.alloc_slice_copy(&[borrowed, irrelevant]), ListWalk => arena.alloc_slice_copy(&[owned, irrelevant, owned]), ListWalkBackwards => arena.alloc_slice_copy(&[owned, irrelevant, owned]), ListSum => arena.alloc_slice_copy(&[borrowed]), // TODO when we have lists with capacity (if ever) // List.append should own its first argument ListAppend => arena.alloc_slice_copy(&[owned, owned]), Eq | NotEq => arena.alloc_slice_copy(&[borrowed, borrowed]), And | Or | NumAdd | NumAddWrap | NumAddChecked | NumSub | NumSubWrap | NumSubChecked | NumMul | NumMulWrap | NumMulChecked | NumGt | NumGte | NumLt | NumLte | NumCompare | NumDivUnchecked | NumRemUnchecked | NumPow | NumPowInt | NumBitwiseAnd | NumBitwiseXor | NumBitwiseOr | NumShiftLeftBy | NumShiftRightBy | NumShiftRightZfBy => { arena.alloc_slice_copy(&[irrelevant, irrelevant]) } NumAbs | NumNeg | NumSin | NumCos | NumSqrtUnchecked | NumRound | NumCeiling | NumFloor | NumToFloat | Not | NumIsFinite | NumAtan | NumAcos | NumAsin | NumIntCast => { arena.alloc_slice_copy(&[irrelevant]) } StrStartsWith | StrEndsWith => arena.alloc_slice_copy(&[owned, borrowed]), StrFromUtf8 => arena.alloc_slice_copy(&[owned]), StrToBytes => arena.alloc_slice_copy(&[owned]), StrFromInt | StrFromFloat => arena.alloc_slice_copy(&[irrelevant]), Hash => arena.alloc_slice_copy(&[borrowed, irrelevant]), DictSize => arena.alloc_slice_copy(&[borrowed]), DictEmpty => &[], DictInsert => arena.alloc_slice_copy(&[owned, owned, owned]), DictRemove => arena.alloc_slice_copy(&[owned, borrowed]), DictContains => arena.alloc_slice_copy(&[borrowed, borrowed]), DictGetUnsafe => arena.alloc_slice_copy(&[borrowed, borrowed]), DictKeys | DictValues => arena.alloc_slice_copy(&[borrowed]), DictUnion | DictDifference | DictIntersection => arena.alloc_slice_copy(&[owned, borrowed]), // borrow function argument so we don't have to worry about RC of the closure DictWalk => arena.alloc_slice_copy(&[owned, borrowed, owned]), SetFromList => arena.alloc_slice_copy(&[owned]), } }