use roc_can::{ expr::{ AnnotatedMark, ClosureData, DeclarationTag::*, Declarations, Expr::{self, *}, }, module::ModuleParams, pattern::{Pattern, RecordDestruct}, }; use roc_collections::VecMap; use roc_module::symbol::{IdentId, IdentIds, ModuleId, Symbol}; use roc_region::all::Loc; use roc_types::{ subs::{VarStore, Variable}, types::Type, }; use std::iter::once; struct LowerParams<'a> { home_id: ModuleId, home_params: &'a Option, imported_params: VecMap, var_store: &'a mut VarStore, ident_ids: &'a mut IdentIds, top_level_idents: Vec, } pub fn lower( home_id: ModuleId, home_params: &Option, imported_params: VecMap, decls: &mut Declarations, ident_ids: &mut IdentIds, var_store: &mut VarStore, ) { let top_level_idents = decls .symbols .iter() .map(|symbol| symbol.value.ident_id()) .collect(); let mut env = LowerParams { home_id, home_params, imported_params, ident_ids, var_store, top_level_idents, }; env.lower_decls(decls); } impl<'a> LowerParams<'a> { fn lower_decls(&mut self, decls: &mut Declarations) { let home_param_symbols = match self.home_params { Some(params) => params .destructs .iter() .map(|destruct| destruct.value.symbol) .chain(once(params.whole_symbol)) .collect::>(), None => vec![], }; for index in 0..decls.len() { let tag = decls.declarations[index]; match tag { Value => { self.lower_expr(&mut decls.expressions[index].value); if let Some(new_arg) = self.home_params_argument() { // This module has params, and this is a top-level value, // so we need to convert it into a function that takes them. decls.convert_value_to_function(index, vec![new_arg], self.var_store); } } Function(fn_def_index) | Recursive(fn_def_index) | TailRecursive(fn_def_index) => { if let Some((var, mark, pattern)) = self.home_params_argument() { // This module has params, and this is a top-level function, // so we need to extend its definition to take them. let function_body = &mut decls.function_bodies[fn_def_index.index()].value; function_body.arguments.push((var, mark, pattern)); // Remove home params from the captured symbols, only nested lambdas need them. function_body .captured_symbols .retain(|(sym, _)| !home_param_symbols.contains(sym)); if let Some(ann) = &mut decls.annotations[index] { if let Type::Function(args, _, _) = &mut ann.signature { args.push(Type::Variable(var)); } } } self.lower_expr(&mut decls.expressions[index].value); } Destructure(_) | Expectation | ExpectationFx => { self.lower_expr(&mut decls.expressions[index].value); } MutualRecursion { .. } => {} } } } fn lower_expr(&mut self, expr: &mut Expr) { let mut expr_stack = vec![expr]; while let Some(expr) = expr_stack.pop() { match expr { // Nodes to lower ParamsVar { symbol, var, params_symbol, params_var, } => { // The module was imported with params, but it might not actually expect them. // We should only lower if it does to prevent confusing type errors. if let Some(arity) = self.get_imported_def_arity(symbol) { *expr = self.lower_naked_params_var( arity, *symbol, *var, *params_symbol, *params_var, ); } } Var(symbol, var) => { if let Some((params, arity)) = self.params_extended_home_symbol(symbol) { *expr = self.lower_naked_params_var( arity, *symbol, *var, params.whole_symbol, params.whole_var, ); } } Call(fun, args, _called_via) => { expr_stack.reserve(args.len() + 1); match fun.1.value { ParamsVar { symbol, var, params_var, params_symbol, } => { // Calling an imported function with params match self.get_imported_def_arity(&symbol) { Some(0) => { // We are calling a function but the top-level declaration has no arguments. // This can either be a function alias or a top-level def that returns functions // under multiple branches. // We call the value def with params, and apply the returned function to the original arguments. fun.1.value = self.call_value_def_with_params( symbol, var, params_symbol, params_var, ); } Some(_) => { // The module expects params and they were provided, we need to extend the call. fun.1.value = Var(symbol, var); args.push(( params_var, Loc::at_zero(Var(params_symbol, params_var)), )); } None => { // The module expects no params, do not extend to prevent confusing type errors. fun.1.value = Var(symbol, var); } } } Var(symbol, var) => { if let Some((params, arity)) = self.params_extended_home_symbol(&symbol) { if arity == 0 { // Calling the result of a top-level value def in the current module fun.1.value = self.call_value_def_with_params( symbol, var, params.whole_symbol, params.whole_var, ); } else { // Calling a top-level function in the current module with params args.push(( params.whole_var, Loc::at_zero(Var(params.whole_symbol, params.whole_var)), )); } } } _ => expr_stack.push(&mut fun.1.value), } for (_, arg) in args.iter_mut() { expr_stack.push(&mut arg.value); } } // Nodes to walk Closure(ClosureData { loc_body, captured_symbols: _, name: _, function_type: _, closure_type: _, return_type: _, recursive: _, arguments: _, }) => { expr_stack.push(&mut loc_body.value); } LetNonRec(def, cont) => { expr_stack.reserve(2); expr_stack.push(&mut def.loc_expr.value); expr_stack.push(&mut cont.value); } LetRec(defs, cont, _cycle_mark) => { expr_stack.reserve(defs.len() + 1); for def in defs { expr_stack.push(&mut def.loc_expr.value); } expr_stack.push(&mut cont.value); } When { loc_cond, branches, cond_var: _, expr_var: _, region: _, branches_cond_var: _, exhaustive: _, } => { expr_stack.reserve(branches.len() + 1); expr_stack.push(&mut loc_cond.value); for branch in branches.iter_mut() { expr_stack.push(&mut branch.value.value); } } If { branches, final_else, cond_var: _, branch_var: _, } => { expr_stack.reserve(branches.len() * 2 + 1); for (cond, ret) in branches.iter_mut() { expr_stack.push(&mut cond.value); expr_stack.push(&mut ret.value); } expr_stack.push(&mut final_else.value); } RunLowLevel { args, op: _, ret_var: _, } | ForeignCall { foreign_symbol: _, args, ret_var: _, } => { expr_stack.extend(args.iter_mut().map(|(_, arg)| arg)); } List { elem_var: _, loc_elems, } => { expr_stack.extend(loc_elems.iter_mut().map(|loc_elem| &mut loc_elem.value)); } Record { record_var: _, fields, } => { expr_stack.extend( fields .iter_mut() .map(|(_, field)| &mut field.loc_expr.value), ); } Tuple { tuple_var: _, elems, } => { expr_stack.extend(elems.iter_mut().map(|(_, elem)| &mut elem.value)); } ImportParams(_, _, Some((_, params_expr))) => { expr_stack.push(params_expr); } Crash { msg, ret_var: _ } => { expr_stack.push(&mut msg.value); } RecordAccess { loc_expr, record_var: _, ext_var: _, field_var: _, field: _, } => expr_stack.push(&mut loc_expr.value), TupleAccess { loc_expr, tuple_var: _, ext_var: _, elem_var: _, index: _, } => expr_stack.push(&mut loc_expr.value), RecordUpdate { updates, record_var: _, ext_var: _, symbol: _, } => expr_stack.extend( updates .iter_mut() .map(|(_, field)| &mut field.loc_expr.value), ), Tag { arguments, tag_union_var: _, ext_var: _, name: _, } => expr_stack.extend(arguments.iter_mut().map(|(_, arg)| &mut arg.value)), OpaqueRef { argument, opaque_var: _, name: _, specialized_def_type: _, type_arguments: _, lambda_set_variables: _, } => expr_stack.push(&mut argument.1.value), Expect { loc_condition, loc_continuation, lookups_in_cond: _, } => { expr_stack.reserve(2); expr_stack.push(&mut loc_condition.value); expr_stack.push(&mut loc_continuation.value); } ExpectFx { loc_condition, loc_continuation, lookups_in_cond: _, } => { expr_stack.reserve(2); expr_stack.push(&mut loc_condition.value); expr_stack.push(&mut loc_continuation.value); } Dbg { loc_message, loc_continuation, source_location: _, source: _, variable: _, symbol: _, } => { expr_stack.reserve(2); expr_stack.push(&mut loc_message.value); expr_stack.push(&mut loc_continuation.value); } RecordAccessor(_) | ImportParams(_, _, None) | ZeroArgumentTag { closure_name: _, variant_var: _, ext_var: _, name: _, } | OpaqueWrapFunction(_) | EmptyRecord | TypedHole(_) | RuntimeError(_) | Num(_, _, _, _) | Int(_, _, _, _, _) | Float(_, _, _, _, _) | Str(_) | SingleQuote(_, _, _, _) | IngestedFile(_, _, _) | AbilityMember(_, _, _) => { /* terminal */ } } } } fn unique_symbol(&mut self) -> Symbol { Symbol::new(self.home_id, self.ident_ids.gen_unique()) } fn home_params_argument(&mut self) -> Option<(Variable, AnnotatedMark, Loc)> { match &self.home_params { Some(module_params) => { let destructs: Vec> = module_params .destructs .iter() .map(|destructure| { destructure.map(|d| RecordDestruct { symbol: d.symbol, var: self.var_store.fresh(), label: d.label.clone(), typ: d.typ.clone(), }) }) .collect(); let record_pattern = Pattern::RecordDestructure { whole_var: module_params.record_var, ext_var: module_params.record_ext_var, destructs, }; let loc_record_pattern = Loc::at(module_params.region, record_pattern); let as_pattern = Pattern::As(Box::new(loc_record_pattern), module_params.whole_symbol); let loc_pattern = Loc::at(module_params.region, as_pattern); Some(( self.var_store.fresh(), AnnotatedMark::new(self.var_store), loc_pattern, )) } None => None, } } fn params_extended_home_symbol(&self, symbol: &Symbol) -> Option<(&ModuleParams, usize)> { if symbol.module_id() == self.home_id { self.home_params.as_ref().and_then(|params| { params .arity_by_name .get(&symbol.ident_id()) .map(|arity| (params, *arity)) }) } else { None } } fn get_imported_def_arity(&self, symbol: &Symbol) -> Option { self.imported_params .get(&symbol.module_id()) .and_then(|params| params.arity_by_name.get(&symbol.ident_id()).copied()) } fn lower_naked_params_var( &mut self, arity: usize, symbol: Symbol, var: Variable, params_symbol: Symbol, params_var: Variable, ) -> Expr { if arity == 0 { // We are passing a top-level value that takes params, so we need to replace the Var // with a call that passes the params to get the final result. // // value = \#params -> #params.x * 2 // record = \... #params -> { doubled: value } // ↓ // value #params self.call_value_def_with_params(symbol, var, params_symbol, params_var) } else { // We are passing a top-level function that takes params, so we need to replace // the Var with a closure that captures the params and passes them to the function. // // fn1 = \arg #params -> #params.x * arg // fn2 = \... #params -> List.map [1, 2] fn1 // ↓ // (\#1 -> fn1 #1 #params) // let mut arguments = Vec::with_capacity(arity); let mut call_arguments = Vec::with_capacity(arity + 1); for _ in 0..arity { let sym = self.unique_symbol(); let var = self.var_store.fresh(); arguments.push(( var, AnnotatedMark::new(self.var_store), Loc::at_zero(Pattern::Identifier(sym)), )); call_arguments.push((var, Loc::at_zero(Var(sym, var)))); } let params_arg = (params_var, Loc::at_zero(Var(params_symbol, params_var))); call_arguments.push(params_arg); let call_fn = Box::new(( self.var_store.fresh(), Loc::at_zero(Var(symbol, var)), self.var_store.fresh(), self.var_store.fresh(), )); let body = Call( call_fn, call_arguments, roc_module::called_via::CalledVia::NakedParamsVar, ); let captured_symbols = if symbol.module_id() == self.home_id || !self.top_level_idents.contains(¶ms_symbol.ident_id()) { vec![(params_symbol, params_var)] } else { vec![] }; Closure(ClosureData { function_type: self.var_store.fresh(), closure_type: self.var_store.fresh(), return_type: self.var_store.fresh(), name: self.unique_symbol(), captured_symbols, recursive: roc_can::expr::Recursive::NotRecursive, arguments, loc_body: Box::new(Loc::at_zero(body)), }) } } fn call_value_def_with_params( &mut self, symbol: Symbol, var: Variable, params_symbol: Symbol, params_var: Variable, ) -> Expr { let params_arg = (params_var, Loc::at_zero(Var(params_symbol, params_var))); let call_fn = Box::new(( self.var_store.fresh(), Loc::at_zero(Var(symbol, var)), self.var_store.fresh(), self.var_store.fresh(), )); Call( call_fn, vec![params_arg], roc_module::called_via::CalledVia::NakedParamsVar, ) } }