#![allow(clippy::manual_map)] use crate::ir::{ Call, CallType, Expr, JoinPointId, Param, Proc, ProcLayout, SelfRecursive, Stmt, UpdateModeId, }; use crate::layout::{ InLayout, LambdaName, Layout, LayoutInterner, LayoutRepr, STLayoutInterner, TagIdIntType, UnionLayout, }; use bumpalo::collections::Vec; use bumpalo::Bump; use roc_collections::{MutMap, VecMap}; use roc_module::low_level::LowLevel; use roc_module::symbol::{IdentIds, ModuleId, Symbol}; pub struct Env<'a, 'i> { arena: &'a Bump, home: ModuleId, interner: &'i mut STLayoutInterner<'a>, ident_ids: &'i mut IdentIds, } impl<'a, 'i> Env<'a, 'i> { fn unique_symbol(&mut self) -> Symbol { let ident_id = self.ident_ids.gen_unique(); Symbol::new(self.home, ident_id) } fn named_unique_symbol(&mut self, name: &str) -> Symbol { let ident_id = self.ident_ids.add_str(name); Symbol::new(self.home, ident_id) } } pub fn apply_trmc<'a, 'i>( arena: &'a Bump, interner: &'i mut STLayoutInterner<'a>, home: ModuleId, ident_ids: &'i mut IdentIds, procs: &mut MutMap<(Symbol, ProcLayout<'a>), Proc<'a>>, ) { let mut env = Env { arena, interner, home, ident_ids, }; let env = &mut env; for proc in procs.values_mut() { use self::SelfRecursive::*; if let SelfRecursive(id) = proc.is_self_recursive { let trmc_candidate_symbols = trmc_candidates(env.interner, proc); if !trmc_candidate_symbols.is_empty() { let new_proc = crate::tail_recursion::TrmcEnv::init(env, proc, trmc_candidate_symbols); *proc = new_proc; } else { let mut args = Vec::with_capacity_in(proc.args.len(), arena); let mut proc_args = Vec::with_capacity_in(proc.args.len(), arena); for (layout, symbol) in proc.args { let new = env.unique_symbol(); args.push((*layout, *symbol, new)); proc_args.push((*layout, new)); } let transformed = crate::tail_recursion::make_tail_recursive( arena, id, proc.name, proc.body.clone(), args.into_bump_slice(), proc.ret_layout, ); if let Some(with_tco) = transformed { proc.body = with_tco; proc.args = proc_args.into_bump_slice(); } } } } } /// Make tail calls into loops (using join points) /// /// e.g. /// /// > factorial n accum = if n == 1 then accum else factorial (n - 1) (n * accum) /// /// becomes /// /// ```elm /// factorial n1 accum1 = /// let joinpoint j n accum = /// if n == 1 then /// accum /// else /// jump j (n - 1) (n * accum) /// /// in /// jump j n1 accum1 /// ``` /// /// This will effectively compile into a loop in llvm, and /// won't grow the call stack for each iteration fn make_tail_recursive<'a>( arena: &'a Bump, id: JoinPointId, needle: LambdaName, stmt: Stmt<'a>, args: &'a [(InLayout<'a>, Symbol, Symbol)], ret_layout: InLayout<'a>, ) -> Option> { let allocated = arena.alloc(stmt); let new_stmt = insert_jumps(arena, allocated, id, needle, args, ret_layout)?; // if we did not early-return, jumps were inserted, we must now add a join point let params = Vec::from_iter_in( args.iter().map(|(layout, symbol, _)| Param { symbol: *symbol, layout: *layout, }), arena, ) .into_bump_slice(); // TODO could this be &[]? let args = Vec::from_iter_in(args.iter().map(|t| t.2), arena).into_bump_slice(); let jump = arena.alloc(Stmt::Jump(id, args)); let join = Stmt::Join { id, remainder: jump, parameters: params, body: new_stmt, }; Some(join) } fn insert_jumps<'a>( arena: &'a Bump, stmt: &'a Stmt<'a>, goal_id: JoinPointId, needle: LambdaName, needle_arguments: &'a [(InLayout<'a>, Symbol, Symbol)], needle_result: InLayout<'a>, ) -> Option<&'a Stmt<'a>> { use Stmt::*; // to insert a tail-call, it must not just be a call to the function itself, but it must also // have the same layout. In particular when lambda sets get involved, a self-recursive call may // have a different type and should not be converted to a jump! let is_equal_function = |function_name: LambdaName, arguments: &[_], result| { let it = needle_arguments.iter().map(|t| &t.0); needle == function_name && it.eq(arguments.iter()) && needle_result == result }; match stmt { Let( symbol, Expr::Call(crate::ir::Call { call_type: CallType::ByName { name: fsym, ret_layout, arg_layouts, .. }, arguments, }), _, Stmt::Ret(rsym), ) if symbol == rsym && is_equal_function(*fsym, arg_layouts, *ret_layout) => { // replace the call and return with a jump let jump = Stmt::Jump(goal_id, arguments); Some(arena.alloc(jump)) } Let(symbol, expr, layout, cont) => { let opt_cont = insert_jumps( arena, cont, goal_id, needle, needle_arguments, needle_result, ); if opt_cont.is_some() { let cont = opt_cont.unwrap_or(cont); Some(arena.alloc(Let(*symbol, expr.clone(), *layout, cont))) } else { None } } Join { id, parameters, remainder, body: continuation, } => { let opt_remainder = insert_jumps( arena, remainder, goal_id, needle, needle_arguments, needle_result, ); let opt_continuation = insert_jumps( arena, continuation, goal_id, needle, needle_arguments, needle_result, ); if opt_remainder.is_some() || opt_continuation.is_some() { let remainder = opt_remainder.unwrap_or(remainder); let continuation = opt_continuation.unwrap_or(*continuation); Some(arena.alloc(Join { id: *id, parameters, remainder, body: continuation, })) } else { None } } Switch { cond_symbol, cond_layout, branches, default_branch, ret_layout, } => { let opt_default = insert_jumps( arena, default_branch.1, goal_id, needle, needle_arguments, needle_result, ); let mut did_change = false; let opt_branches = Vec::from_iter_in( branches.iter().map(|(label, info, branch)| { match insert_jumps( arena, branch, goal_id, needle, needle_arguments, needle_result, ) { None => None, Some(branch) => { did_change = true; Some((*label, info.clone(), branch.clone())) } } }), arena, ); if opt_default.is_some() || did_change { let default_branch = ( default_branch.0.clone(), opt_default.unwrap_or(default_branch.1), ); let branches = if did_change { let new = Vec::from_iter_in( opt_branches.into_iter().zip(branches.iter()).map( |(opt_branch, branch)| match opt_branch { None => branch.clone(), Some(new_branch) => new_branch, }, ), arena, ); new.into_bump_slice() } else { branches }; Some(arena.alloc(Switch { cond_symbol: *cond_symbol, cond_layout: *cond_layout, default_branch, branches, ret_layout: *ret_layout, })) } else { None } } Refcounting(modify, cont) => { match insert_jumps( arena, cont, goal_id, needle, needle_arguments, needle_result, ) { Some(cont) => Some(arena.alloc(Refcounting(*modify, cont))), None => None, } } Dbg { source_location, source, symbol, variable, remainder, } => match insert_jumps( arena, remainder, goal_id, needle, needle_arguments, needle_result, ) { Some(cont) => Some(arena.alloc(Dbg { source_location, source, symbol: *symbol, variable: *variable, remainder: cont, })), None => None, }, Expect { condition, region, lookups, variables, remainder, } => match insert_jumps( arena, remainder, goal_id, needle, needle_arguments, needle_result, ) { Some(cont) => Some(arena.alloc(Expect { condition: *condition, region: *region, lookups, variables, remainder: cont, })), None => None, }, ExpectFx { condition, region, lookups, variables, remainder, } => match insert_jumps( arena, remainder, goal_id, needle, needle_arguments, needle_result, ) { Some(cont) => Some(arena.alloc(ExpectFx { condition: *condition, region: *region, lookups, variables, remainder: cont, })), None => None, }, Ret(_) => None, Jump(_, _) => None, Crash(..) => None, } } #[derive(Debug, Default)] struct TrmcCandidateSet { interner: arrayvec::ArrayVec, confirmed: u64, active: u64, invalid: u64, } impl TrmcCandidateSet { fn confirmed(&self) -> impl Iterator + '_ { self.interner .iter() .enumerate() .filter_map(|(i, s)| (self.confirmed & (1 << i) != 0).then_some(*s)) } fn active(&self) -> impl Iterator + '_ { self.interner .iter() .enumerate() .filter_map(|(i, s)| (self.active & (1 << i) != 0).then_some(*s)) } fn position(&self, symbol: Symbol) -> Option { self.interner.iter().position(|s| *s == symbol) } fn insert(&mut self, symbol: Symbol) { // there really is no way it could have been inserted already debug_assert!(self.position(symbol).is_none()); let index = self.interner.len(); self.interner.push(symbol); self.active |= 1 << index; } fn retain(&mut self, keep: F) where F: Fn(&Symbol) -> bool, { for (i, s) in self.interner.iter().enumerate() { if !keep(s) { let mask = 1 << i; self.active &= !mask; self.confirmed &= !mask; self.invalid |= mask; } } } fn confirm(&mut self, symbol: Symbol) { match self.position(symbol) { None => debug_assert_eq!(0, 1, "confirm of invalid symbol"), Some(index) => { let mask = 1 << index; debug_assert_eq!(self.invalid & mask, 0); debug_assert_ne!(self.active & mask, 0); self.active &= !mask; self.confirmed |= mask; } } } fn is_empty(&self) -> bool { self.confirmed == 0 } } fn trmc_candidates<'a, I>(interner: &'_ I, proc: &'_ Proc<'a>) -> TrmcCandidateSet where I: LayoutInterner<'a>, { // it must be a self-recursive function if !matches!( proc.is_self_recursive, crate::ir::SelfRecursive::SelfRecursive(_) ) { return TrmcCandidateSet::default(); } // and return a recursive tag union if !matches!(interner.get_repr(proc.ret_layout), LayoutRepr::Union(union_layout) if union_layout.is_recursive()) { return TrmcCandidateSet::default(); } let mut candidate_set = TrmcCandidateSet::default(); trmc_candidates_help(proc.name, &proc.body, &mut candidate_set); candidate_set } fn trmc_candidates_help( function_name: LambdaName, stmt: &'_ Stmt<'_>, candidates: &mut TrmcCandidateSet, ) { // if this stmt is the literal tail tag application and return, then this is a TRMC opportunity if let Some(cons_info) = TrmcEnv::is_terminal_constructor(stmt) { // the tag application must directly use the result of the recursive call let recursive_call = candidates .active() .find(|call| cons_info.arguments.contains(call)); // if we find a usage, this is a confirmed TRMC call if let Some(recursive_call) = recursive_call { candidates.confirm(recursive_call); return; } } // if the stmt uses the active recursive call, that invalidates the recursive call for this branch candidates.retain(|recursive_call| !stmt_contains_symbol_nonrec(stmt, *recursive_call)); match stmt { Stmt::Let(symbol, expr, _, next) => { // find a new recursive call if we currently have none // that means we generally pick the first recursive call we find if TrmcEnv::is_recursive_expr(expr, function_name).is_some() { candidates.insert(*symbol); } trmc_candidates_help(function_name, next, candidates) } Stmt::Switch { branches, default_branch, .. } => { let it = branches .iter() .map(|(_, _, stmt)| stmt) .chain([default_branch.1]); for next in it { trmc_candidates_help(function_name, next, candidates); } } Stmt::Refcounting(_, next) => trmc_candidates_help(function_name, next, candidates), Stmt::Expect { remainder, .. } | Stmt::ExpectFx { remainder, .. } | Stmt::Dbg { remainder, .. } => trmc_candidates_help(function_name, remainder, candidates), Stmt::Join { body, remainder, .. } => { trmc_candidates_help(function_name, body, candidates); trmc_candidates_help(function_name, remainder, candidates); } Stmt::Ret(_) | Stmt::Jump(_, _) | Stmt::Crash(_, _) => { /* terminal */ } } } // TRMC (tail recursion modulo constructor) is an optimization for some recursive functions that return a recursive data type. The most basic example is a repeat function on linked lists: // // ```roc // LinkedList a : [ Nil, Cons a (LinkedList a) ] // // repeat : a, Nat -> LinkedList a // repeat = \element, n -> // when n is // 0 -> Nil // _ -> Cons element (repeat element (n - 1)) // ``` // // This function is recursive, but cannot use standard tail-call elimintation, because the recursive call is not in tail position (i.e. the last thing happening before a return). Rather the recursive call is an argument to a constructor of the recursive output type. This means that `repeat n` will creat `n` stack frames. For big inputs, a stack overflow is inevitable. // // But there is a trick: TRMC. Using TRMC and join points, we are able to convert this function into a loop, which uses only one stack frame for the whole process. // // ```pseudo-roc // repeat : a, Nat -> LinkedList a // repeat = \initialElement, initialN -> // joinpoint trmc = \element, n, hole, head -> // when n is // 0 -> // # write the value `Nil` into the hole // *hole = Nil // # dereference (load from) the pointer to the first element // *head // // _ -> // *hole = Cons element NULL // newHole = &hole.Cons.1 // jump trmc element (n - 1) newHole head // in // # creates a stack allocation, gives a pointer to that stack allocation // initial : Ptr (LinkedList a) = #alloca NULL // jump trmc initialElement initialN initial initial // ``` // // The functionality here figures out whether this transformation can be applied in valid way, and then performs the transformation. #[derive(Clone)] pub(crate) struct TrmcEnv<'a> { lambda_name: LambdaName<'a>, /// Current hole to fill hole_symbol: Symbol, /// Pointer to the first constructor ("the head of the list") head_symbol: Symbol, joinpoint_id: JoinPointId, return_layout: InLayout<'a>, ptr_return_layout: InLayout<'a>, trmc_calls: VecMap>>, } #[derive(Debug)] struct ConstructorInfo<'a> { tag_layout: UnionLayout<'a>, tag_id: TagIdIntType, arguments: &'a [Symbol], } impl<'a> TrmcEnv<'a> { #[inline(always)] fn is_terminal_constructor(stmt: &Stmt<'a>) -> Option> { match stmt { Stmt::Let(s1, expr, _layout, Stmt::Ret(s2)) if s1 == s2 => { Self::get_contructor_info(expr) } _ => None, } } fn get_contructor_info(expr: &Expr<'a>) -> Option> { if let Expr::Tag { tag_layout, tag_id, arguments, reuse, } = expr { debug_assert!(reuse.is_none()); let info = ConstructorInfo { tag_layout: *tag_layout, tag_id: *tag_id, arguments, }; Some(info) } else { None } } fn is_recursive_expr(expr: &Expr<'a>, lambda_name: LambdaName<'_>) -> Option> { if let Expr::Call(call) = expr { Self::is_recursive_call(call, lambda_name).then_some(call.clone()) } else { None } } fn is_recursive_call(call: &Call<'a>, lambda_name: LambdaName<'_>) -> bool { match call.call_type { CallType::ByName { name, .. } => { // because we do not allow polymorphic recursion, this is the only constraint name == lambda_name } CallType::ByPointer { .. } => false, CallType::Foreign { .. } | CallType::LowLevel { .. } | CallType::HigherOrder(_) => { false } } } fn is_tail_recursive_call( lambda_name: LambdaName, symbol: Symbol, expr: &Expr<'a>, next: &Stmt<'a>, ) -> Option> { match next { Stmt::Ret(s) if *s == symbol => Self::is_recursive_expr(expr, lambda_name), _ => None, } } fn ptr_write( env: &mut Env<'a, '_>, ptr: Symbol, value: Symbol, next: &'a Stmt<'a>, ) -> Stmt<'a> { let ptr_write = Call { call_type: crate::ir::CallType::LowLevel { op: LowLevel::PtrStore, update_mode: UpdateModeId::BACKEND_DUMMY, }, arguments: env.arena.alloc([ptr, value]), }; Stmt::Let( env.named_unique_symbol("_ptr_write_unit"), Expr::Call(ptr_write), Layout::UNIT, next, ) } fn init<'i>(env: &mut Env<'a, 'i>, proc: &Proc<'a>, trmc_calls: TrmcCandidateSet) -> Proc<'a> { let arena = env.arena; let return_layout = proc.ret_layout; let mut joinpoint_parameters = Vec::with_capacity_in(proc.args.len() + 2, env.arena); let mut new_proc_arguments = Vec::with_capacity_in(proc.args.len(), env.arena); let mut jump_arguments = Vec::with_capacity_in(proc.args.len() + 2, env.arena); for (i, (layout, old_symbol)) in proc.args.iter().enumerate() { let symbol = env.named_unique_symbol(&format!("arg_{i}")); new_proc_arguments.push((*layout, symbol)); jump_arguments.push(symbol); let param = Param { symbol: *old_symbol, layout: *layout, }; joinpoint_parameters.push(param); } // the root of the recursive structure that we'll be building let initial_ptr_symbol = env.named_unique_symbol("initial"); jump_arguments.push(initial_ptr_symbol); jump_arguments.push(initial_ptr_symbol); let null_symbol = env.named_unique_symbol("null"); let let_null = |next| Stmt::Let(null_symbol, Expr::NullPointer, return_layout, next); let ptr_return_layout = env .interner .insert_direct_no_semantic(LayoutRepr::Ptr(return_layout)); let ptr_null = Expr::Alloca { initializer: Some(null_symbol), element_layout: return_layout, }; let let_ptr = |next| Stmt::Let(initial_ptr_symbol, ptr_null, ptr_return_layout, next); let joinpoint_id = JoinPointId(env.named_unique_symbol("trmc")); let hole_symbol = env.named_unique_symbol("hole"); let head_symbol = env.named_unique_symbol("head"); let jump_stmt = Stmt::Jump(joinpoint_id, jump_arguments.into_bump_slice()); let trmc_calls = trmc_calls.confirmed().map(|s| (s, None)).collect(); let mut this = Self { lambda_name: proc.name, hole_symbol, head_symbol, joinpoint_id, return_layout, ptr_return_layout, trmc_calls, }; let param = Param { symbol: hole_symbol, layout: ptr_return_layout, }; joinpoint_parameters.push(param); let param = Param { symbol: head_symbol, layout: ptr_return_layout, }; joinpoint_parameters.push(param); let joinpoint = Stmt::Join { id: joinpoint_id, parameters: joinpoint_parameters.into_bump_slice(), body: arena.alloc(this.walk_stmt(env, &proc.body)), remainder: arena.alloc(jump_stmt), }; let body = let_null(arena.alloc( // let_ptr(arena.alloc( // joinpoint, )), )); #[cfg(debug_assertions)] env.home.register_debug_idents(env.ident_ids); Proc { name: proc.name, args: new_proc_arguments.into_bump_slice(), body, closure_data_layout: proc.closure_data_layout, ret_layout: proc.ret_layout, is_self_recursive: SelfRecursive::NotSelfRecursive, is_erased: proc.is_erased, } } fn walk_stmt(&mut self, env: &mut Env<'a, '_>, stmt: &Stmt<'a>) -> Stmt<'a> { let arena = env.arena; match stmt { Stmt::Let(symbol, expr, layout, next) => { // if this is a TRMC call, remember what the call looks like, so we can turn it // into a jump later. The call is then removed from the Stmt if let Some(opt_call) = self.trmc_calls.get_mut(symbol) { debug_assert!( opt_call.is_none(), "didn't expect to visit call again since symbols are unique" ); let call = match expr { Expr::Call(call) => call, _ => unreachable!(), }; *opt_call = Some(call.clone()); return self.walk_stmt(env, next); } if let Some(call) = Self::is_tail_recursive_call(self.lambda_name, *symbol, expr, next) { // turn the call into a jump. Just re-use the existing hole let mut arguments = Vec::new_in(arena); arguments.extend(call.arguments); arguments.push(self.hole_symbol); arguments.push(self.head_symbol); let jump = Stmt::Jump(self.joinpoint_id, arguments.into_bump_slice()); return jump; } if let Some(cons_info) = Self::is_terminal_constructor(stmt) { // figure out which TRMC call to use here. We pick the first one that works let opt_recursive_call = cons_info.arguments.iter().find_map(|arg| { self.trmc_calls .get(arg) .and_then(|x| x.as_ref()) .map(|x| (arg, x)) }); match opt_recursive_call { None => { // this control flow path did not encounter a recursive call. Just // write the end result into the hole and we're done. let define_tag = |next| Stmt::Let(*symbol, expr.clone(), *layout, next); let output = define_tag(arena.alloc( // self.non_trmc_return(env, *symbol), )); return output; } Some((call_symbol, call)) => { // we did encounter a recursive call, and can perform TRMC in this // branch. let opt_recursive_field_index = cons_info.arguments.iter().position(|s| *s == *call_symbol); let recursive_field_index = match opt_recursive_field_index { None => { let next = self.walk_stmt(env, next); return Stmt::Let( *symbol, expr.clone(), *layout, arena.alloc(next), ); } Some(v) => v, }; let tag_arg_null_symbol = env.named_unique_symbol("tag_arg_null"); let let_tag_arg_null = |next| { Stmt::Let( tag_arg_null_symbol, Expr::NullPointer, self.return_layout, next, ) }; let mut arguments = Vec::from_iter_in(cons_info.arguments.iter().copied(), env.arena); arguments[recursive_field_index] = tag_arg_null_symbol; let tag_expr = Expr::Tag { tag_layout: cons_info.tag_layout, tag_id: cons_info.tag_id, arguments: arguments.into_bump_slice(), reuse: None, }; let indices = arena .alloc([cons_info.tag_id as u64, recursive_field_index as u64]); let let_tag = |next| Stmt::Let(*symbol, tag_expr, *layout, next); let get_reference_expr = Expr::GetElementPointer { structure: *symbol, union_layout: cons_info.tag_layout, indices, }; let new_hole_symbol = env.named_unique_symbol("newHole"); let let_new_hole = |next| { Stmt::Let( new_hole_symbol, get_reference_expr, self.ptr_return_layout, next, ) }; let mut jump_arguments = Vec::from_iter_in(call.arguments.iter().copied(), env.arena); jump_arguments.push(new_hole_symbol); jump_arguments.push(self.head_symbol); let jump = Stmt::Jump(self.joinpoint_id, jump_arguments.into_bump_slice()); let output = let_tag_arg_null(arena.alloc( // let_tag(arena.alloc( // let_new_hole(arena.alloc( // Self::ptr_write( env, self.hole_symbol, *symbol, arena.alloc(jump), ), )), )), )); return output; } } } let next = self.walk_stmt(env, next); Stmt::Let(*symbol, expr.clone(), *layout, arena.alloc(next)) } Stmt::Switch { cond_symbol, cond_layout, branches, default_branch, ret_layout, } => { let mut new_branches = Vec::with_capacity_in(branches.len(), arena); for (id, info, stmt) in branches.iter() { let new_stmt = self.walk_stmt(env, stmt); new_branches.push((*id, info.clone(), new_stmt)); } let new_default_branch = &*arena.alloc(self.walk_stmt(env, default_branch.1)); Stmt::Switch { cond_symbol: *cond_symbol, cond_layout: *cond_layout, branches: arena.alloc(new_branches.into_bump_slice()), default_branch: (default_branch.0.clone(), new_default_branch), ret_layout: *ret_layout, } } Stmt::Ret(symbol) => { // write the symbol we're supposed to return into the hole // then read initial_symbol and return its contents self.non_trmc_return(env, *symbol) } Stmt::Refcounting(op, next) => { let new_next = self.walk_stmt(env, next); Stmt::Refcounting(*op, arena.alloc(new_next)) } Stmt::Expect { condition, region, lookups, variables, remainder, } => Stmt::Expect { condition: *condition, region: *region, lookups, variables, remainder: arena.alloc(self.walk_stmt(env, remainder)), }, Stmt::ExpectFx { condition, region, lookups, variables, remainder, } => Stmt::Expect { condition: *condition, region: *region, lookups, variables, remainder: arena.alloc(self.walk_stmt(env, remainder)), }, Stmt::Dbg { source_location, source, symbol, variable, remainder, } => Stmt::Dbg { source_location, source, symbol: *symbol, variable: *variable, remainder: arena.alloc(self.walk_stmt(env, remainder)), }, Stmt::Join { id, parameters, body, remainder, } => { let new_body = self.walk_stmt(env, body); let new_remainder = self.walk_stmt(env, remainder); Stmt::Join { id: *id, parameters, body: arena.alloc(new_body), remainder: arena.alloc(new_remainder), } } Stmt::Jump(id, arguments) => Stmt::Jump(*id, arguments), Stmt::Crash(symbol, crash_tag) => Stmt::Crash(*symbol, *crash_tag), } } fn non_trmc_return(&mut self, env: &mut Env<'a, '_>, value_symbol: Symbol) -> Stmt<'a> { let arena = env.arena; let layout = self.return_layout; let final_symbol = env.named_unique_symbol("final"); let call = Call { call_type: CallType::LowLevel { op: LowLevel::PtrLoad, update_mode: UpdateModeId::BACKEND_DUMMY, }, arguments: &*arena.alloc([self.head_symbol]), }; let ptr_load = |next| Stmt::Let(final_symbol, Expr::Call(call), layout, next); Self::ptr_write( env, self.hole_symbol, value_symbol, arena.alloc( // ptr_load(arena.alloc(Stmt::Ret(final_symbol))), ), ) } } fn expr_contains_symbol(expr: &Expr, needle: Symbol) -> bool { match expr { Expr::Literal(_) => false, Expr::Call(call) => call.arguments.contains(&needle), Expr::Tag { arguments, reuse, .. } => match reuse { None => arguments.contains(&needle), Some(ru) => ru.symbol == needle || arguments.contains(&needle), }, Expr::Struct(fields) => fields.contains(&needle), Expr::NullPointer | Expr::FunctionPointer { .. } => false, Expr::StructAtIndex { structure, .. } | Expr::GetTagId { structure, .. } | Expr::UnionAtIndex { structure, .. } | Expr::GetElementPointer { structure, .. } => needle == *structure, Expr::Array { elems, .. } => elems.iter().any(|element| match element { crate::ir::ListLiteralElement::Literal(_) => false, crate::ir::ListLiteralElement::Symbol(symbol) => needle == *symbol, }), Expr::EmptyArray => false, Expr::Reset { symbol, .. } | Expr::ResetRef { symbol, .. } => needle == *symbol, Expr::RuntimeErrorFunction(_) => false, Expr::ErasedMake { value, callee } => { value.map(|v| v == needle).unwrap_or(false) || needle == *callee } Expr::ErasedLoad { symbol, field: _ } => needle == *symbol, Expr::Alloca { initializer, .. } => &Some(needle) == initializer, } } fn stmt_contains_symbol_nonrec(stmt: &Stmt, needle: Symbol) -> bool { use crate::ir::ModifyRc::*; match stmt { Stmt::Let(_, expr, _, _) => expr_contains_symbol(expr, needle), Stmt::Switch { cond_symbol, .. } => needle == *cond_symbol, Stmt::Ret(symbol) => needle == *symbol, Stmt::Refcounting(modify, _) => { matches!( modify, Inc(symbol, _) | Dec(symbol) | DecRef(symbol) if needle == *symbol ) } Stmt::Expect { condition, lookups, .. } | Stmt::ExpectFx { condition, lookups, .. } => needle == *condition || lookups.contains(&needle), Stmt::Dbg { symbol, .. } => needle == *symbol, Stmt::Join { .. } => false, Stmt::Jump(_, arguments) => arguments.contains(&needle), Stmt::Crash(symbol, _) => needle == *symbol, } }