diff --git a/compiler/module/src/symbol.rs b/compiler/module/src/symbol.rs index 9b7ac01c14..847501b1ec 100644 --- a/compiler/module/src/symbol.rs +++ b/compiler/module/src/symbol.rs @@ -774,6 +774,9 @@ define_builtins! { // a caller (wrapper) for comparison 21 GENERIC_COMPARE_REF: "#generic_compare_ref" + + // used to initialize paramters in borrow.rs + 22 EMPTY_PARAM: "#empty_param" } 1 NUM: "Num" => { 0 NUM_NUM: "Num" imported // the Num.Num type alias diff --git a/compiler/mono/src/borrow.rs b/compiler/mono/src/borrow.rs index 6b0035de47..7c8aabdb00 100644 --- a/compiler/mono/src/borrow.rs +++ b/compiler/mono/src/borrow.rs @@ -23,14 +23,17 @@ pub fn infer_borrow<'a>( // intern the layouts let mut declaration_to_index = MutMap::with_capacity_and_hasher(procs.len(), default_hasher()); - for (i, key) in procs.keys().enumerate() { + let mut i = 0; + for key in procs.keys() { declaration_to_index.insert(*key, DeclarationId(i)); + + i += key.1.arguments.len(); } let mut param_map = ParamMap { declaration_to_index, join_points: MutMap::default(), - declarations: bumpalo::vec![in arena; &[] as &[_]; procs.len()], + declarations: bumpalo::vec![in arena; Param::EMPTY; i], }; for (key, proc) in procs { @@ -42,7 +45,6 @@ pub fn infer_borrow<'a>( param_set: MutSet::default(), owned: MutMap::default(), modified: false, - param_map, arena, }; @@ -59,7 +61,7 @@ pub fn infer_borrow<'a>( // mutually recursive functions (or just make all their arguments owned) for (key, proc) in procs { - env.collect_proc(proc, key.1); + env.collect_proc(&mut param_map, proc, key.1); } if !env.modified { @@ -71,7 +73,7 @@ pub fn infer_borrow<'a>( } } - env.param_map + param_map } #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] @@ -85,17 +87,23 @@ impl From for usize { #[derive(Debug, Clone)] pub struct ParamMap<'a> { + /// Map a (Symbol, ProcLayout) pair to the starting index in the `declarations` array declaration_to_index: MutMap<(Symbol, ProcLayout<'a>), DeclarationId>, - // IDEA: flatten the declarations into just one flat array - declarations: Vec<'a, &'a [Param<'a>]>, + /// the parameters of all functions in a single flat array. + /// + /// - the map above gives the index of the first parameter for the function + /// - the length of the ProcLayout's argument field gives the total number of parameters + /// + /// These can be read by taking a slice into this array, and can also be updated in-place + declarations: Vec<'a, Param<'a>>, join_points: MutMap]>, } impl<'a> ParamMap<'a> { - pub fn get_symbol(&self, symbol: Symbol, layout: ProcLayout<'a>) -> Option<&'a [Param<'a>]> { + pub fn get_symbol(&self, symbol: Symbol, layout: ProcLayout<'a>) -> Option<&[Param<'a>]> { let index: usize = self.declaration_to_index[&(symbol, layout)].into(); - self.declarations.get(index).copied() + self.declarations.get(index..index + layout.arguments.len()) } pub fn get_join_point(&self, id: JoinPointId) -> &'a [Param<'a>] { match self.join_points.get(&id) { @@ -156,7 +164,14 @@ impl<'a> ParamMap<'a> { } let index: usize = self.declaration_to_index[&key].into(); - self.declarations[index] = Self::init_borrow_args(arena, proc.args); + + for (i, param) in Self::init_borrow_args(arena, proc.args) + .iter() + .copied() + .enumerate() + { + self.declarations[index + i] = param; + } self.visit_stmt(arena, proc.name, &proc.body); } @@ -168,7 +183,14 @@ impl<'a> ParamMap<'a> { key: (Symbol, ProcLayout<'a>), ) { let index: usize = self.declaration_to_index[&key].into(); - self.declarations[index] = Self::init_borrow_args_always_owned(arena, proc.args); + + for (i, param) in Self::init_borrow_args_always_owned(arena, proc.args) + .iter() + .copied() + .enumerate() + { + self.declarations[index + i] = param; + } self.visit_stmt(arena, proc.name, &proc.body); } @@ -224,7 +246,6 @@ struct BorrowInfState<'a> { param_set: MutSet, owned: MutMap>, modified: bool, - param_map: ParamMap<'a>, arena: &'a Bump, } @@ -269,17 +290,31 @@ impl<'a> BorrowInfState<'a> { new_ps.into_bump_slice() } - fn update_param_map_declaration(&mut self, symbol: Symbol, layout: ProcLayout<'a>) { - let index: usize = self.param_map.declaration_to_index[&(symbol, layout)].into(); + fn update_param_map_declaration( + &mut self, + param_map: &mut ParamMap<'a>, + symbol: Symbol, + layout: ProcLayout<'a>, + ) { + let index: usize = param_map.declaration_to_index[&(symbol, layout)].into(); + let ps = &mut param_map.declarations[index..][..layout.arguments.len()]; - let ps = self.param_map.declarations[index]; - self.param_map.declarations[index] = self.update_param_map_help(ps); + for p in ps.iter_mut() { + if !p.borrow { + // do nothing + } else if self.is_owned(p.symbol) { + self.modified = true; + p.borrow = false; + } else { + // do nothing + } + } } - fn update_param_map_join_point(&mut self, id: JoinPointId) { - let ps = self.param_map.join_points[&id]; + fn update_param_map_join_point(&mut self, param_map: &mut ParamMap<'a>, id: JoinPointId) { + let ps = param_map.join_points[&id]; let new_ps = self.update_param_map_help(ps); - self.param_map.join_points.insert(id, new_ps); + param_map.join_points.insert(id, new_ps); } /// This looks at an application `f x1 x2 x3` @@ -346,7 +381,7 @@ impl<'a> BorrowInfState<'a> { /// /// and determines whether z and which of the symbols used in e /// must be taken as owned parameters - fn collect_call(&mut self, z: Symbol, e: &crate::ir::Call<'a>) { + fn collect_call(&mut self, param_map: &mut ParamMap<'a>, z: Symbol, e: &crate::ir::Call<'a>) { use crate::ir::CallType::*; let crate::ir::Call { @@ -364,8 +399,7 @@ impl<'a> BorrowInfState<'a> { let top_level = ProcLayout::new(self.arena, arg_layouts, *ret_layout); // get the borrow signature of the applied function - let ps = self - .param_map + let ps = param_map .get_symbol(*name, top_level) .expect("function is defined"); @@ -381,6 +415,7 @@ impl<'a> BorrowInfState<'a> { ps.len(), arguments.len() ); + self.own_args_using_params(arguments, ps); } @@ -411,7 +446,7 @@ impl<'a> BorrowInfState<'a> { match op { ListMap | ListKeepIf | ListKeepOks | ListKeepErrs => { - match self.param_map.get_symbol(arguments[1], closure_layout) { + match param_map.get_symbol(arguments[1], closure_layout) { Some(function_ps) => { // own the list if the function wants to own the element if !function_ps[0].borrow { @@ -427,7 +462,7 @@ impl<'a> BorrowInfState<'a> { } } ListMapWithIndex => { - match self.param_map.get_symbol(arguments[1], closure_layout) { + match param_map.get_symbol(arguments[1], closure_layout) { Some(function_ps) => { // own the list if the function wants to own the element if !function_ps[1].borrow { @@ -442,7 +477,7 @@ impl<'a> BorrowInfState<'a> { None => unreachable!(), } } - ListMap2 => match self.param_map.get_symbol(arguments[2], closure_layout) { + ListMap2 => match param_map.get_symbol(arguments[2], closure_layout) { Some(function_ps) => { // own the lists if the function wants to own the element if !function_ps[0].borrow { @@ -460,7 +495,7 @@ impl<'a> BorrowInfState<'a> { } None => unreachable!(), }, - ListMap3 => match self.param_map.get_symbol(arguments[3], closure_layout) { + ListMap3 => match param_map.get_symbol(arguments[3], closure_layout) { Some(function_ps) => { // own the lists if the function wants to own the element if !function_ps[0].borrow { @@ -481,7 +516,7 @@ impl<'a> BorrowInfState<'a> { None => unreachable!(), }, ListSortWith => { - match self.param_map.get_symbol(arguments[1], closure_layout) { + match param_map.get_symbol(arguments[1], closure_layout) { Some(function_ps) => { // always own the input list self.own_var(arguments[0]); @@ -495,7 +530,7 @@ impl<'a> BorrowInfState<'a> { } } ListWalk | ListWalkUntil | ListWalkBackwards | DictWalk => { - match self.param_map.get_symbol(arguments[2], closure_layout) { + match param_map.get_symbol(arguments[2], closure_layout) { Some(function_ps) => { // own the data structure if the function wants to own the element if !function_ps[0].borrow { @@ -537,7 +572,7 @@ impl<'a> BorrowInfState<'a> { } } - fn collect_expr(&mut self, z: Symbol, e: &Expr<'a>) { + fn collect_expr(&mut self, param_map: &mut ParamMap<'a>, z: Symbol, e: &Expr<'a>) { use Expr::*; match e { @@ -565,7 +600,7 @@ impl<'a> BorrowInfState<'a> { self.own_var(z); } - Call(call) => self.collect_call(z, call), + Call(call) => self.collect_call(param_map, z, call), Literal(_) | RuntimeErrorFunction(_) => {} @@ -608,7 +643,13 @@ impl<'a> BorrowInfState<'a> { } #[allow(clippy::many_single_char_names)] - fn preserve_tail_call(&mut self, x: Symbol, v: &Expr<'a>, b: &Stmt<'a>) { + fn preserve_tail_call( + &mut self, + param_map: &mut ParamMap<'a>, + x: Symbol, + v: &Expr<'a>, + b: &Stmt<'a>, + ) { if let ( Expr::Call(crate::ir::Call { call_type: @@ -629,7 +670,7 @@ impl<'a> BorrowInfState<'a> { if self.current_proc == *g && x == *z { // anonymous functions (for which the ps may not be known) // can never be tail-recursive, so this is fine - if let Some(ps) = self.param_map.get_symbol(*g, top_level) { + if let Some(ps) = param_map.get_symbol(*g, top_level) { self.own_params_using_args(ys, ps) } } @@ -648,7 +689,7 @@ impl<'a> BorrowInfState<'a> { } } - fn collect_stmt(&mut self, stmt: &Stmt<'a>) { + fn collect_stmt(&mut self, param_map: &mut ParamMap<'a>, stmt: &Stmt<'a>) { use Stmt::*; match stmt { @@ -660,17 +701,17 @@ impl<'a> BorrowInfState<'a> { } => { let old = self.param_set.clone(); self.update_param_set(ys); - self.collect_stmt(v); + self.collect_stmt(param_map, v); self.param_set = old; - self.update_param_map_join_point(*j); + self.update_param_map_join_point(param_map, *j); - self.collect_stmt(b); + self.collect_stmt(param_map, b); } Let(x, v, _, b) => { - self.collect_stmt(b); - self.collect_expr(*x, v); - self.preserve_tail_call(*x, v, b); + self.collect_stmt(param_map, b); + self.collect_expr(param_map, *x, v); + self.preserve_tail_call(param_map, *x, v, b); } Invoke { @@ -681,17 +722,17 @@ impl<'a> BorrowInfState<'a> { fail, exception_id: _, } => { - self.collect_stmt(pass); - self.collect_stmt(fail); + self.collect_stmt(param_map, pass); + self.collect_stmt(param_map, fail); - self.collect_call(*symbol, call); + self.collect_call(param_map, *symbol, call); // TODO how to preserve the tail call of an invoke? // self.preserve_tail_call(*x, v, b); } Jump(j, ys) => { - let ps = self.param_map.get_join_point(*j); + let ps = param_map.get_join_point(*j); // for making sure the join point can reuse self.own_args_using_params(ys, ps); @@ -705,9 +746,9 @@ impl<'a> BorrowInfState<'a> { .. } => { for (_, _, b) in branches.iter() { - self.collect_stmt(b); + self.collect_stmt(param_map, b); } - self.collect_stmt(default_branch.1); + self.collect_stmt(param_map, default_branch.1); } Refcounting(_, _) => unreachable!("these have not been introduced yet"), @@ -717,7 +758,12 @@ impl<'a> BorrowInfState<'a> { } } - fn collect_proc(&mut self, proc: &Proc<'a>, layout: ProcLayout<'a>) { + fn collect_proc( + &mut self, + param_map: &mut ParamMap<'a>, + proc: &Proc<'a>, + layout: ProcLayout<'a>, + ) { let old = self.param_set.clone(); let ys = Vec::from_iter_in(proc.args.iter().map(|t| t.1), self.arena).into_bump_slice(); @@ -727,8 +773,8 @@ impl<'a> BorrowInfState<'a> { // ensure that current_proc is in the owned map self.owned.entry(proc.name).or_default(); - self.collect_stmt(&proc.body); - self.update_param_map_declaration(proc.name, layout); + self.collect_stmt(param_map, &proc.body); + self.update_param_map_declaration(param_map, proc.name, layout); self.param_set = old; }