diff --git a/crates/compiler/can/src/def.rs b/crates/compiler/can/src/def.rs index 6af033064b..625c7113bb 100644 --- a/crates/compiler/can/src/def.rs +++ b/crates/compiler/can/src/def.rs @@ -113,35 +113,21 @@ impl Annotation { self } - pub fn add_arguments(&mut self, argument_count: usize, var_store: &mut VarStore) { - match self.signature { - Type::Function(ref mut arg_types, _, _) => { - arg_types.reserve(argument_count); + pub fn convert_to_fn(&mut self, argument_count: usize, var_store: &mut VarStore) { + let mut arg_types = Vec::with_capacity(argument_count); - for _ in 0..argument_count { - let var = var_store.fresh(); - self.introduced_variables.insert_inferred(Loc::at_zero(var)); + for _ in 0..argument_count { + let var = var_store.fresh(); + self.introduced_variables.insert_inferred(Loc::at_zero(var)); - arg_types.push(Type::Variable(var)); - } - } - _ => { - let mut arg_types = Vec::with_capacity(argument_count); - - for _ in 0..argument_count { - let var = var_store.fresh(); - self.introduced_variables.insert_inferred(Loc::at_zero(var)); - - arg_types.push(Type::Variable(var)); - } - - self.signature = Type::Function( - arg_types, - Box::new(Type::Variable(var_store.fresh())), - Box::new(self.signature.clone()), - ); - } + arg_types.push(Type::Variable(var)); } + + self.signature = Type::Function( + arg_types, + Box::new(Type::Variable(var_store.fresh())), + Box::new(self.signature.clone()), + ); } } diff --git a/crates/compiler/can/src/expr.rs b/crates/compiler/can/src/expr.rs index 55fcb7c28f..35315d4f62 100644 --- a/crates/compiler/can/src/expr.rs +++ b/crates/compiler/can/src/expr.rs @@ -3135,11 +3135,11 @@ impl Declarations { Index::push_new(&mut self.function_bodies, loc_function_def); if let Some(annotation) = &mut self.annotations[index] { - annotation.add_arguments(new_args_len, var_store); + annotation.convert_to_fn(new_args_len, var_store); } if let Some((_var, annotation)) = self.host_exposed_annotations.get_mut(&index) { - annotation.add_arguments(new_args_len, var_store); + annotation.convert_to_fn(new_args_len, var_store); } self.declarations[index] = DeclarationTag::Function(function_def_index); diff --git a/crates/compiler/lower_params/src/lower.rs b/crates/compiler/lower_params/src/lower.rs index 60da9fbea9..e137e9acb2 100644 --- a/crates/compiler/lower_params/src/lower.rs +++ b/crates/compiler/lower_params/src/lower.rs @@ -11,7 +11,10 @@ use roc_can::{ use roc_collections::VecMap; use roc_module::symbol::{IdentId, IdentIds, ModuleId, Symbol}; use roc_region::all::Loc; -use roc_types::subs::{VarStore, Variable}; +use roc_types::{ + subs::{VarStore, Variable}, + types::Type, +}; struct LowerParams<'a> { home_id: ModuleId, @@ -55,21 +58,13 @@ impl<'a> LowerParams<'a> { match tag { Value => { - let aliased = self.lower_expr(true, &mut decls.expressions[index].value); + self.lower_expr(&mut decls.expressions[index].value); if let Some(new_arg) = self.home_params_argument() { - if !aliased { - // This module has params, and this is a top-level value, - // so we need to convert it into a function that takes them. + // This module has params, and this is a top-level value, + // so we need to convert it into a function that takes them. - decls.convert_value_to_function(index, vec![new_arg], self.var_store); - } else { - // This value def is just aliasing another params extended def, - // we only need to fix the annotation - if let Some(ann) = &mut decls.annotations[index] { - ann.add_arguments(1, self.var_store); - } - } + decls.convert_value_to_function(index, vec![new_arg], self.var_store); } } Function(fn_def_index) | Recursive(fn_def_index) | TailRecursive(fn_def_index) => { @@ -83,24 +78,25 @@ impl<'a> LowerParams<'a> { .push((var, mark, pattern)); if let Some(ann) = &mut decls.annotations[index] { - ann.add_arguments(1, self.var_store); + if let Type::Function(args, _, _) = &mut ann.signature { + args.push(Type::Variable(var)); + } } } - self.lower_expr(false, &mut decls.expressions[index].value); + self.lower_expr(&mut decls.expressions[index].value); } Destructure(_) | Expectation | ExpectationFx => { - self.lower_expr(false, &mut decls.expressions[index].value); + self.lower_expr(&mut decls.expressions[index].value); } MutualRecursion { .. } => {} } } } - fn lower_expr(&mut self, is_value_def: bool, expr: &mut Expr) -> bool { + fn lower_expr(&mut self, expr: &mut Expr) { let mut expr_stack = vec![expr]; - let mut aliased = false; while let Some(expr) = expr_stack.pop() { match expr { @@ -113,11 +109,9 @@ impl<'a> LowerParams<'a> { } => { // The module was imported with params, but it might not actually expect them. // We should only lower if it does to prevent confusing type errors. - if let Some(params) = self.imported_params.get(&symbol.module_id()) { - let arity = params.arity_by_name.get(&symbol.ident_id()).unwrap(); - + if let Some(arity) = self.get_imported_def_arity(symbol) { *expr = self.lower_naked_params_var( - *arity, + arity, *symbol, *var, *params_symbol, @@ -127,12 +121,6 @@ impl<'a> LowerParams<'a> { } Var(symbol, var) => { if let Some((params, arity)) = self.params_extended_home_symbol(symbol) { - if is_value_def { - // Aliased top-level def, no need to lower - aliased = true; - continue; - } - *expr = self.lower_naked_params_var( arity, *symbol, @@ -154,23 +142,52 @@ impl<'a> LowerParams<'a> { } => { // Calling an imported function with params - // Extend arguments only if the imported module actually expects params - if self.imported_params.contains_key(&symbol.module_id()) { - args.push(( - params_var, - Loc::at_zero(Var(params_symbol, params_var)), - )); - } + match self.get_imported_def_arity(&symbol) { + Some(0) => { + // We are calling a function but the top-level declaration has no arguments. + // This can either be a function alias or a top-level def that returns functions + // under multiple branches. + // We call the value def with params, and apply the returned function to the original arguments. + fun.1.value = self.call_value_def_with_params( + symbol, + var, + params_symbol, + params_var, + ); + } + Some(_) => { + // The module expects params and they were provided, we need to extend the call. + fun.1.value = Var(symbol, var); - fun.1.value = Var(symbol, var); + args.push(( + params_var, + Loc::at_zero(Var(params_symbol, params_var)), + )); + } + None => { + // The module expects no params, do not extend to prevent confusing type errors. + fun.1.value = Var(symbol, var); + } + } } Var(symbol, _var) => { - if let Some((params, _)) = self.params_extended_home_symbol(&symbol) { - // Calling a top-level function in the current module with params - args.push(( - params.whole_var, - Loc::at_zero(Var(params.whole_symbol, params.whole_var)), - )); + if let Some((params, arity)) = self.params_extended_home_symbol(&symbol) + { + if arity == 0 { + // Calling the result of a top-level value def in the current module + fun.1.value = self.call_value_def_with_params( + symbol, + params.whole_var, + params.whole_symbol, + params.whole_var, + ); + } else { + // Calling a top-level function in the current module with params + args.push(( + params.whole_var, + Loc::at_zero(Var(params.whole_symbol, params.whole_var)), + )); + } } } _ => expr_stack.push(&mut fun.1.value), @@ -370,14 +387,26 @@ impl<'a> LowerParams<'a> { | AbilityMember(_, _, _) => { /* terminal */ } } } - - aliased } fn unique_symbol(&mut self) -> Symbol { Symbol::new(self.home_id, self.ident_ids.gen_unique()) } + fn home_params_argument(&mut self) -> Option<(Variable, AnnotatedMark, Loc)> { + match &self.home_params { + Some(module_params) => { + let new_var = self.var_store.fresh(); + Some(( + new_var, + AnnotatedMark::new(self.var_store), + module_params.pattern(), + )) + } + None => None, + } + } + fn params_extended_home_symbol(&self, symbol: &Symbol) -> Option<(&ModuleParams, usize)> { if symbol.module_id() == self.home_id { match self.home_params { @@ -392,6 +421,12 @@ impl<'a> LowerParams<'a> { } } + fn get_imported_def_arity(&self, symbol: &Symbol) -> Option { + self.imported_params + .get(&symbol.module_id()) + .and_then(|params| params.arity_by_name.get(&symbol.ident_id()).copied()) + } + fn lower_naked_params_var( &mut self, arity: usize, @@ -400,14 +435,6 @@ impl<'a> LowerParams<'a> { params_symbol: Symbol, params_var: Variable, ) -> Expr { - let params_arg = (params_var, Loc::at_zero(Var(params_symbol, params_var))); - let call_fn = Box::new(( - self.var_store.fresh(), - Loc::at_zero(Var(symbol, var)), - self.var_store.fresh(), - self.var_store.fresh(), - )); - if arity == 0 { // We are passing a top-level value that takes params, so we need to replace the Var // with a call that passes the params to get the final result. @@ -416,12 +443,7 @@ impl<'a> LowerParams<'a> { // record = \... #params -> { doubled: value } // ↓ // value #params - Call( - call_fn, - vec![params_arg], - // todo: custom called via - roc_module::called_via::CalledVia::Space, - ) + self.call_value_def_with_params(symbol, var, params_symbol, params_var) } else { // We are passing a top-level function that takes params, so we need to replace // the Var with a closure that captures the params and passes them to the function. @@ -446,8 +468,17 @@ impl<'a> LowerParams<'a> { call_arguments.push((var, Loc::at_zero(Var(sym, var)))); } + let params_arg = (params_var, Loc::at_zero(Var(params_symbol, params_var))); + call_arguments.push(params_arg); + let call_fn = Box::new(( + self.var_store.fresh(), + Loc::at_zero(Var(symbol, var)), + self.var_store.fresh(), + self.var_store.fresh(), + )); + let body = Call( call_fn, call_arguments, @@ -475,17 +506,28 @@ impl<'a> LowerParams<'a> { }) } } - fn home_params_argument(&mut self) -> Option<(Variable, AnnotatedMark, Loc)> { - match &self.home_params { - Some(module_params) => { - let new_var = self.var_store.fresh(); - Some(( - new_var, - AnnotatedMark::new(self.var_store), - module_params.pattern(), - )) - } - None => None, - } + + fn call_value_def_with_params( + &mut self, + symbol: Symbol, + var: Variable, + params_symbol: Symbol, + params_var: Variable, + ) -> Expr { + let params_arg = (params_var, Loc::at_zero(Var(params_symbol, params_var))); + + let call_fn = Box::new(( + self.var_store.fresh(), + Loc::at_zero(Var(symbol, var)), + self.var_store.fresh(), + self.var_store.fresh(), + )); + + Call( + call_fn, + vec![params_arg], + // todo: custom called via + roc_module::called_via::CalledVia::Space, + ) } }