diff --git a/compiler/mono/src/borrow.rs b/compiler/mono/src/borrow.rs index a905df57db..d27c9d0f67 100644 --- a/compiler/mono/src/borrow.rs +++ b/compiler/mono/src/borrow.rs @@ -156,6 +156,10 @@ impl<'a> ParamMap<'a> { Let(_, _, _, cont) => { stack.push(cont); } + Invoke { pass, fail, .. } => { + stack.push(pass); + stack.push(fail); + } Switch { branches, default_branch, @@ -295,6 +299,62 @@ impl<'a> BorrowInfState<'a> { /// /// and determines whether z and which of the symbols used in e /// must be taken as owned paramters + fn collect_call(&mut self, z: Symbol, e: &crate::ir::Call<'a>) { + use crate::ir::CallType::*; + + let crate::ir::Call { + call_type, + arguments, + } = e; + + match call_type { + ByName { + name, arg_layouts, .. + } + | ByPointer { + name, arg_layouts, .. + } => { + // get the borrow signature of the applied function + let ps = match self.param_map.get_symbol(*name) { + Some(slice) => slice, + None => Vec::from_iter_in( + arg_layouts.iter().cloned().map(|layout| Param { + symbol: Symbol::UNDERSCORE, + borrow: false, + layout, + }), + self.arena, + ) + .into_bump_slice(), + }; + + // the return value will be owned + self.own_var(z); + + // if the function exects an owned argument (ps), the argument must be owned (args) + self.own_args_using_params(arguments, ps); + } + + LowLevel { op } => { + // very unsure what demand RunLowLevel should place upon its arguments + self.own_var(z); + + let ps = lowlevel_borrow_signature(self.arena, *op); + + self.own_args_using_bools(arguments, ps); + } + + Foreign { .. } => { + // very unsure what demand ForeignCall should place upon its arguments + self.own_var(z); + + let ps = foreign_borrow_signature(self.arena, arguments.len()); + + self.own_args_using_bools(arguments, ps); + } + } + } + fn collect_expr(&mut self, z: Symbol, e: &Expr<'a>) { use Expr::*; @@ -334,59 +394,7 @@ impl<'a> BorrowInfState<'a> { } } - Call(crate::ir::Call { - call_type, - arguments, - }) => { - use crate::ir::CallType::*; - - match call_type { - ByName { - name, arg_layouts, .. - } - | ByPointer { - name, arg_layouts, .. - } => { - // get the borrow signature of the applied function - let ps = match self.param_map.get_symbol(*name) { - Some(slice) => slice, - None => Vec::from_iter_in( - arg_layouts.iter().cloned().map(|layout| Param { - symbol: Symbol::UNDERSCORE, - borrow: false, - layout, - }), - self.arena, - ) - .into_bump_slice(), - }; - - // the return value will be owned - self.own_var(z); - - // if the function exects an owned argument (ps), the argument must be owned (args) - self.own_args_using_params(arguments, ps); - } - - LowLevel { op } => { - // very unsure what demand RunLowLevel should place upon its arguments - self.own_var(z); - - let ps = lowlevel_borrow_signature(self.arena, *op); - - self.own_args_using_bools(arguments, ps); - } - - Foreign { .. } => { - // very unsure what demand ForeignCall should place upon its arguments - self.own_var(z); - - let ps = foreign_borrow_signature(self.arena, arguments.len()); - - self.own_args_using_bools(arguments, ps); - } - } - } + Call(call) => self.collect_call(z, call), Literal(_) | FunctionPointer(_, _) | RuntimeErrorFunction(_) => {} } @@ -462,11 +470,29 @@ impl<'a> BorrowInfState<'a> { self.collect_stmt(b); self.preserve_tail_call(*x, &Expr::FunctionPointer(*fsymbol, layout.clone()), b); } + Let(x, v, _, b) => { self.collect_stmt(b); self.collect_expr(*x, v); self.preserve_tail_call(*x, v, b); } + + Invoke { + symbol, + call, + layout, + pass, + fail, + } => { + self.collect_stmt(pass); + self.collect_stmt(fail); + + self.collect_call(*symbol, call); + + // TODO how to preserve the tail call of an invoke? + // self.preserve_tail_call(*x, v, b); + } + Jump(j, ys) => { let ps = self.param_map.get_join_point(*j); diff --git a/compiler/mono/src/inc_dec.rs b/compiler/mono/src/inc_dec.rs index 807ac801d9..8b805c5db9 100644 --- a/compiler/mono/src/inc_dec.rs +++ b/compiler/mono/src/inc_dec.rs @@ -31,6 +31,21 @@ pub fn occuring_variables(stmt: &Stmt<'_>) -> (MutSet, MutSet) { bound_variables.insert(*symbol); stack.push(cont); } + + Invoke { + symbol, + call, + pass, + fail, + .. + } => { + occuring_variables_call(call, &mut result); + result.insert(*symbol); + bound_variables.insert(*symbol); + stack.push(pass); + stack.push(fail); + } + Ret(symbol) => { result.insert(*symbol); } @@ -77,6 +92,12 @@ pub fn occuring_variables(stmt: &Stmt<'_>) -> (MutSet, MutSet) { (result, bound_variables) } +fn occuring_variables_call(call: &crate::ir::Call<'_>, result: &mut MutSet) { + // NOTE though the function name does occur, it is a static constant in the program + // for liveness, it should not be included here. + result.extend(call.arguments.iter().copied()); +} + pub fn occuring_variables_expr(expr: &Expr<'_>, result: &mut MutSet) { use Expr::*; @@ -88,11 +109,7 @@ pub fn occuring_variables_expr(expr: &Expr<'_>, result: &mut MutSet) { result.insert(*symbol); } - Call(crate::ir::Call { arguments, .. }) => { - // NOTE thouth the function name does occur, it is a static constant in the program - // for liveness, it should not be included here. - result.extend(arguments.iter().copied()); - } + Call(call) => occuring_variables_call(call, result), Tag { arguments, .. } | Struct(arguments) @@ -204,6 +221,11 @@ fn consume_expr(m: &VarMap, e: &Expr<'_>) -> bool { } } +fn consume_call(m: &VarMap, e: &crate::ir::Call<'_>) -> bool { + // variables bound by a call (or invoke) must always be consumed + true +} + impl<'a> Context<'a> { pub fn new(arena: &'a Bump, param_map: &'a ParamMap<'a>) -> Self { let mut vars = MutMap::default(); @@ -406,6 +428,75 @@ impl<'a> Context<'a> { b } + fn visit_call( + &self, + z: Symbol, + call_type: crate::ir::CallType<'a>, + arguments: &'a [Symbol], + l: Layout<'a>, + b: &'a Stmt<'a>, + b_live_vars: &LiveVarSet, + ) -> &'a Stmt<'a> { + use crate::ir::CallType::*; + + match &call_type { + LowLevel { op } => { + let ps = crate::borrow::lowlevel_borrow_signature(self.arena, *op); + let b = self.add_dec_after_lowlevel(arguments, ps, b, b_live_vars); + + let v = Expr::Call(crate::ir::Call { + call_type, + arguments, + }); + + &*self.arena.alloc(Stmt::Let(z, v, l, b)) + } + + Foreign { .. } => { + let ps = crate::borrow::foreign_borrow_signature(self.arena, arguments.len()); + let b = self.add_dec_after_lowlevel(arguments, ps, b, b_live_vars); + + let v = Expr::Call(crate::ir::Call { + call_type, + arguments, + }); + + &*self.arena.alloc(Stmt::Let(z, v, l, b)) + } + + ByName { + name, arg_layouts, .. + } + | ByPointer { + name, arg_layouts, .. + } => { + // get the borrow signature + let ps = match self.param_map.get_symbol(*name) { + Some(slice) => slice, + None => Vec::from_iter_in( + arg_layouts.iter().cloned().map(|layout| Param { + symbol: Symbol::UNDERSCORE, + borrow: false, + layout, + }), + self.arena, + ) + .into_bump_slice(), + }; + + let v = Expr::Call(crate::ir::Call { + call_type, + arguments, + }); + + let b = self.add_dec_after_application(arguments, ps, b, b_live_vars); + let b = self.arena.alloc(Stmt::Let(z, v, l, b)); + + self.add_inc_before(arguments, ps, b, b_live_vars) + } + } + } + #[allow(clippy::many_single_char_names)] fn visit_variable_declaration( &self, @@ -442,54 +533,9 @@ impl<'a> Context<'a> { } Call(crate::ir::Call { - ref call_type, + call_type, arguments, - }) => { - use crate::ir::CallType::*; - - match &call_type { - LowLevel { op } => { - let ps = crate::borrow::lowlevel_borrow_signature(self.arena, *op); - let b = self.add_dec_after_lowlevel(arguments, ps, b, b_live_vars); - - &*self.arena.alloc(Stmt::Let(z, v, l, b)) - } - - Foreign { .. } => { - let ps = - crate::borrow::foreign_borrow_signature(self.arena, arguments.len()); - let b = self.add_dec_after_lowlevel(arguments, ps, b, b_live_vars); - - &*self.arena.alloc(Stmt::Let(z, v, l, b)) - } - - ByName { - name, arg_layouts, .. - } - | ByPointer { - name, arg_layouts, .. - } => { - // get the borrow signature - let ps = match self.param_map.get_symbol(*name) { - Some(slice) => slice, - None => Vec::from_iter_in( - arg_layouts.iter().cloned().map(|layout| Param { - symbol: Symbol::UNDERSCORE, - borrow: false, - layout, - }), - self.arena, - ) - .into_bump_slice(), - }; - - let b = self.add_dec_after_application(arguments, ps, b, b_live_vars); - let b = self.arena.alloc(Stmt::Let(z, v, l, b)); - - self.add_inc_before(arguments, ps, b, b_live_vars) - } - } - } + }) => self.visit_call(z, call_type, arguments, l, b, b_live_vars), EmptyArray | FunctionPointer(_, _) @@ -505,12 +551,23 @@ impl<'a> Context<'a> { (new_b, live_vars) } + fn update_var_info_invoke( + &self, + symbol: Symbol, + layout: &Layout<'a>, + call: &crate::ir::Call<'a>, + ) -> Self { + // is this value a constant? + // TODO do function pointers also fall into this category? + let persistent = call.arguments.is_empty(); + + // must this value be consumed? + let consume = consume_call(&self.vars, call); + + self.update_var_info_help(symbol, layout, persistent, consume) + } + fn update_var_info(&self, symbol: Symbol, layout: &Layout<'a>, expr: &Expr<'a>) -> Self { - let mut ctx = self.clone(); - - // can this type be reference-counted at runtime? - let reference = layout.contains_refcounted(); - // is this value a constant? // TODO do function pointers also fall into this category? let persistent = match expr { @@ -519,7 +576,20 @@ impl<'a> Context<'a> { }; // must this value be consumed? - let consume = consume_expr(&ctx.vars, expr); + let consume = consume_expr(&self.vars, expr); + + self.update_var_info_help(symbol, layout, persistent, consume) + } + + fn update_var_info_help( + &self, + symbol: Symbol, + layout: &Layout<'a>, + persistent: bool, + consume: bool, + ) -> Self { + // can this type be reference-counted at runtime? + let reference = layout.contains_refcounted(); let info = VarInfo { reference, @@ -527,6 +597,8 @@ impl<'a> Context<'a> { consume, }; + let mut ctx = self.clone(); + ctx.vars.insert(symbol, info); ctx @@ -634,6 +706,46 @@ impl<'a> Context<'a> { ) } + Invoke { + symbol, + call, + pass, + fail, + layout, + } => { + // TODO this combines parts of Let and Switch. Did this happen correctly? + let mut case_live_vars = collect_stmt(stmt, &self.jp_live_vars, MutSet::default()); + + case_live_vars.remove(symbol); + + let fail = { + // TODO should we use ctor info like Lean? + let ctx = self.clone(); + let (b, alt_live_vars) = ctx.visit_stmt(fail); + ctx.add_dec_for_alt(&case_live_vars, &alt_live_vars, b) + }; + + case_live_vars.insert(*symbol); + + let pass = { + // TODO should we use ctor info like Lean? + let ctx = self.clone(); + let ctx = ctx.update_var_info_invoke(*symbol, layout, call); + let (b, alt_live_vars) = ctx.visit_stmt(pass); + ctx.add_dec_for_alt(&case_live_vars, &alt_live_vars, b) + }; + + let invoke = Invoke { + symbol: *symbol, + call: call.clone(), + pass, + fail, + layout: layout.clone(), + }; + let stmt = self.arena.alloc(invoke); + + (stmt, case_live_vars) + } Join { id: j, parameters: _, @@ -765,6 +877,25 @@ pub fn collect_stmt( vars } + Invoke { + symbol, + call, + pass, + fail, + .. + } => { + vars = collect_stmt(pass, jp_live_vars, vars); + vars = collect_stmt(fail, jp_live_vars, vars); + + vars.remove(symbol); + + let mut result = MutSet::default(); + occuring_variables_call(call, &mut result); + + vars.extend(result); + + vars + } Ret(symbol) => { vars.insert(*symbol); vars diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index b8cddd053f..436fefe5f9 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -741,6 +741,13 @@ pub type Stores<'a> = &'a [(Symbol, Layout<'a>, Expr<'a>)]; #[derive(Clone, Debug, PartialEq)] pub enum Stmt<'a> { Let(Symbol, Expr<'a>, Layout<'a>, &'a Stmt<'a>), + Invoke { + symbol: Symbol, + call: Call<'a>, + layout: Layout<'a>, + pass: &'a Stmt<'a>, + fail: &'a Stmt<'a>, + }, Switch { /// This *must* stand for an integer, because Switch potentially compiles to a jump table. cond_symbol: Symbol, @@ -1121,6 +1128,11 @@ impl<'a> Stmt<'a> { .append(alloc.hardline()) .append(cont.to_doc(alloc)), + Invoke { symbol, .. } => alloc + .text("invoke ") + .append(symbol_to_doc(alloc, *symbol)) + .append(" = ?"), + Ret(symbol) => alloc .text("ret ") .append(symbol_to_doc(alloc, *symbol)) @@ -4365,6 +4377,33 @@ fn substitute_in_stmt_help<'a>( None } } + Invoke { + symbol, + call, + layout, + pass, + fail, + } => { + let opt_call = substitute_in_call(arena, call, subs); + let opt_pass = substitute_in_stmt_help(arena, pass, subs); + let opt_fail = substitute_in_stmt_help(arena, fail, subs); + + if opt_pass.is_some() || opt_fail.is_some() { + let pass = opt_pass.unwrap_or(pass); + let fail = opt_fail.unwrap_or_else(|| *fail); + let call = opt_call.unwrap_or_else(|| call.clone()); + + Some(arena.alloc(Invoke { + symbol: *symbol, + call, + layout: layout.clone(), + pass, + fail, + })) + } else { + None + } + } Join { id, parameters, @@ -4483,6 +4522,69 @@ fn substitute_in_stmt_help<'a>( } } +fn substitute_in_call<'a>( + arena: &'a Bump, + call: &'a Call<'a>, + subs: &MutMap, +) -> Option> { + let Call { + call_type, + arguments, + } = call; + + let opt_call_type = match call_type { + CallType::ByName { + name, + arg_layouts, + ret_layout, + full_layout, + } => substitute(subs, *name).map(|new| CallType::ByName { + name: new, + arg_layouts, + ret_layout: ret_layout.clone(), + full_layout: full_layout.clone(), + }), + CallType::ByPointer { + name, + arg_layouts, + ret_layout, + full_layout, + } => substitute(subs, *name).map(|new| CallType::ByPointer { + name: new, + arg_layouts, + ret_layout: ret_layout.clone(), + full_layout: full_layout.clone(), + }), + CallType::Foreign { .. } => None, + CallType::LowLevel { .. } => None, + }; + + let mut did_change = false; + let new_args = Vec::from_iter_in( + arguments.iter().map(|s| match substitute(subs, *s) { + None => *s, + Some(s) => { + did_change = true; + s + } + }), + arena, + ); + + if did_change || opt_call_type.is_some() { + let call_type = opt_call_type.unwrap_or_else(|| call_type.clone()); + + let arguments = new_args.into_bump_slice(); + + Some(self::Call { + call_type, + arguments, + }) + } else { + None + } +} + fn substitute_in_expr<'a>( arena: &'a Bump, expr: &'a Expr<'a>, @@ -4493,62 +4595,7 @@ fn substitute_in_expr<'a>( match expr { Literal(_) | FunctionPointer(_, _) | EmptyArray | RuntimeErrorFunction(_) => None, - Call(self::Call { - call_type, - arguments, - }) => { - let opt_call_type = match call_type { - CallType::ByName { - name, - arg_layouts, - ret_layout, - full_layout, - } => substitute(subs, *name).map(|new| CallType::ByName { - name: new, - arg_layouts, - ret_layout: ret_layout.clone(), - full_layout: full_layout.clone(), - }), - CallType::ByPointer { - name, - arg_layouts, - ret_layout, - full_layout, - } => substitute(subs, *name).map(|new| CallType::ByPointer { - name: new, - arg_layouts, - ret_layout: ret_layout.clone(), - full_layout: full_layout.clone(), - }), - CallType::Foreign { .. } => None, - CallType::LowLevel { .. } => None, - }; - - let mut did_change = false; - let new_args = Vec::from_iter_in( - arguments.iter().map(|s| match substitute(subs, *s) { - None => *s, - Some(s) => { - did_change = true; - s - } - }), - arena, - ); - - if did_change || opt_call_type.is_some() { - let call_type = opt_call_type.unwrap_or_else(|| call_type.clone()); - - let arguments = new_args.into_bump_slice(); - - Some(Expr::Call(self::Call { - call_type, - arguments, - })) - } else { - None - } - } + Call(call) => substitute_in_call(arena, call, subs).map(|new| Expr::Call(new)), Tag { tag_layout, diff --git a/compiler/mono/src/tail_recursion.rs b/compiler/mono/src/tail_recursion.rs index 264e672ed6..4998c5fd84 100644 --- a/compiler/mono/src/tail_recursion.rs +++ b/compiler/mono/src/tail_recursion.rs @@ -90,6 +90,27 @@ fn insert_jumps<'a>( Some(arena.alloc(jump)) } + Invoke { + symbol, + call: + crate::ir::Call { + call_type: CallType::ByName { name: fsym, .. }, + arguments, + .. + }, + fail, + pass: Stmt::Ret(rsym), + .. + } if needle == *fsym && symbol == rsym => { + debug_assert_eq!(fail, &&Stmt::Unreachable); + + // replace the call and return with a jump + + let jump = Stmt::Jump(goal_id, arguments); + + Some(arena.alloc(jump)) + } + Let(symbol, expr, layout, cont) => { let opt_cont = insert_jumps(arena, cont, goal_id, needle); @@ -101,6 +122,35 @@ fn insert_jumps<'a>( None } } + + Invoke { + symbol, + call, + fail, + pass, + layout, + } => { + let opt_pass = insert_jumps(arena, pass, goal_id, needle); + let opt_fail = insert_jumps(arena, fail, goal_id, needle); + + if opt_pass.is_some() || opt_fail.is_some() { + let pass = opt_pass.unwrap_or(pass); + let fail = opt_fail.unwrap_or(fail); + + let stmt = Invoke { + symbol: *symbol, + call: call.clone(), + layout: layout.clone(), + pass, + fail, + }; + + Some(arena.alloc(stmt)) + } else { + None + } + } + Join { id, parameters,