Handle aliases at call instead

This commit is contained in:
Agus Zubiaga 2024-08-27 02:51:23 -03:00
parent b70d48fd33
commit 80770fae11
No known key found for this signature in database
3 changed files with 126 additions and 98 deletions

View file

@ -113,35 +113,21 @@ impl Annotation {
self
}
pub fn add_arguments(&mut self, argument_count: usize, var_store: &mut VarStore) {
match self.signature {
Type::Function(ref mut arg_types, _, _) => {
arg_types.reserve(argument_count);
pub fn convert_to_fn(&mut self, argument_count: usize, var_store: &mut VarStore) {
let mut arg_types = Vec::with_capacity(argument_count);
for _ in 0..argument_count {
let var = var_store.fresh();
self.introduced_variables.insert_inferred(Loc::at_zero(var));
for _ in 0..argument_count {
let var = var_store.fresh();
self.introduced_variables.insert_inferred(Loc::at_zero(var));
arg_types.push(Type::Variable(var));
}
}
_ => {
let mut arg_types = Vec::with_capacity(argument_count);
for _ in 0..argument_count {
let var = var_store.fresh();
self.introduced_variables.insert_inferred(Loc::at_zero(var));
arg_types.push(Type::Variable(var));
}
self.signature = Type::Function(
arg_types,
Box::new(Type::Variable(var_store.fresh())),
Box::new(self.signature.clone()),
);
}
arg_types.push(Type::Variable(var));
}
self.signature = Type::Function(
arg_types,
Box::new(Type::Variable(var_store.fresh())),
Box::new(self.signature.clone()),
);
}
}

View file

@ -3135,11 +3135,11 @@ impl Declarations {
Index::push_new(&mut self.function_bodies, loc_function_def);
if let Some(annotation) = &mut self.annotations[index] {
annotation.add_arguments(new_args_len, var_store);
annotation.convert_to_fn(new_args_len, var_store);
}
if let Some((_var, annotation)) = self.host_exposed_annotations.get_mut(&index) {
annotation.add_arguments(new_args_len, var_store);
annotation.convert_to_fn(new_args_len, var_store);
}
self.declarations[index] = DeclarationTag::Function(function_def_index);

View file

@ -11,7 +11,10 @@ use roc_can::{
use roc_collections::VecMap;
use roc_module::symbol::{IdentId, IdentIds, ModuleId, Symbol};
use roc_region::all::Loc;
use roc_types::subs::{VarStore, Variable};
use roc_types::{
subs::{VarStore, Variable},
types::Type,
};
struct LowerParams<'a> {
home_id: ModuleId,
@ -55,21 +58,13 @@ impl<'a> LowerParams<'a> {
match tag {
Value => {
let aliased = self.lower_expr(true, &mut decls.expressions[index].value);
self.lower_expr(&mut decls.expressions[index].value);
if let Some(new_arg) = self.home_params_argument() {
if !aliased {
// This module has params, and this is a top-level value,
// so we need to convert it into a function that takes them.
// This module has params, and this is a top-level value,
// so we need to convert it into a function that takes them.
decls.convert_value_to_function(index, vec![new_arg], self.var_store);
} else {
// This value def is just aliasing another params extended def,
// we only need to fix the annotation
if let Some(ann) = &mut decls.annotations[index] {
ann.add_arguments(1, self.var_store);
}
}
decls.convert_value_to_function(index, vec![new_arg], self.var_store);
}
}
Function(fn_def_index) | Recursive(fn_def_index) | TailRecursive(fn_def_index) => {
@ -83,24 +78,25 @@ impl<'a> LowerParams<'a> {
.push((var, mark, pattern));
if let Some(ann) = &mut decls.annotations[index] {
ann.add_arguments(1, self.var_store);
if let Type::Function(args, _, _) = &mut ann.signature {
args.push(Type::Variable(var));
}
}
}
self.lower_expr(false, &mut decls.expressions[index].value);
self.lower_expr(&mut decls.expressions[index].value);
}
Destructure(_) | Expectation | ExpectationFx => {
self.lower_expr(false, &mut decls.expressions[index].value);
self.lower_expr(&mut decls.expressions[index].value);
}
MutualRecursion { .. } => {}
}
}
}
fn lower_expr(&mut self, is_value_def: bool, expr: &mut Expr) -> bool {
fn lower_expr(&mut self, expr: &mut Expr) {
let mut expr_stack = vec![expr];
let mut aliased = false;
while let Some(expr) = expr_stack.pop() {
match expr {
@ -113,11 +109,9 @@ impl<'a> LowerParams<'a> {
} => {
// The module was imported with params, but it might not actually expect them.
// We should only lower if it does to prevent confusing type errors.
if let Some(params) = self.imported_params.get(&symbol.module_id()) {
let arity = params.arity_by_name.get(&symbol.ident_id()).unwrap();
if let Some(arity) = self.get_imported_def_arity(symbol) {
*expr = self.lower_naked_params_var(
*arity,
arity,
*symbol,
*var,
*params_symbol,
@ -127,12 +121,6 @@ impl<'a> LowerParams<'a> {
}
Var(symbol, var) => {
if let Some((params, arity)) = self.params_extended_home_symbol(symbol) {
if is_value_def {
// Aliased top-level def, no need to lower
aliased = true;
continue;
}
*expr = self.lower_naked_params_var(
arity,
*symbol,
@ -154,23 +142,52 @@ impl<'a> LowerParams<'a> {
} => {
// Calling an imported function with params
// Extend arguments only if the imported module actually expects params
if self.imported_params.contains_key(&symbol.module_id()) {
args.push((
params_var,
Loc::at_zero(Var(params_symbol, params_var)),
));
}
match self.get_imported_def_arity(&symbol) {
Some(0) => {
// We are calling a function but the top-level declaration has no arguments.
// This can either be a function alias or a top-level def that returns functions
// under multiple branches.
// We call the value def with params, and apply the returned function to the original arguments.
fun.1.value = self.call_value_def_with_params(
symbol,
var,
params_symbol,
params_var,
);
}
Some(_) => {
// The module expects params and they were provided, we need to extend the call.
fun.1.value = Var(symbol, var);
fun.1.value = Var(symbol, var);
args.push((
params_var,
Loc::at_zero(Var(params_symbol, params_var)),
));
}
None => {
// The module expects no params, do not extend to prevent confusing type errors.
fun.1.value = Var(symbol, var);
}
}
}
Var(symbol, _var) => {
if let Some((params, _)) = self.params_extended_home_symbol(&symbol) {
// Calling a top-level function in the current module with params
args.push((
params.whole_var,
Loc::at_zero(Var(params.whole_symbol, params.whole_var)),
));
if let Some((params, arity)) = self.params_extended_home_symbol(&symbol)
{
if arity == 0 {
// Calling the result of a top-level value def in the current module
fun.1.value = self.call_value_def_with_params(
symbol,
params.whole_var,
params.whole_symbol,
params.whole_var,
);
} else {
// Calling a top-level function in the current module with params
args.push((
params.whole_var,
Loc::at_zero(Var(params.whole_symbol, params.whole_var)),
));
}
}
}
_ => expr_stack.push(&mut fun.1.value),
@ -370,14 +387,26 @@ impl<'a> LowerParams<'a> {
| AbilityMember(_, _, _) => { /* terminal */ }
}
}
aliased
}
fn unique_symbol(&mut self) -> Symbol {
Symbol::new(self.home_id, self.ident_ids.gen_unique())
}
fn home_params_argument(&mut self) -> Option<(Variable, AnnotatedMark, Loc<Pattern>)> {
match &self.home_params {
Some(module_params) => {
let new_var = self.var_store.fresh();
Some((
new_var,
AnnotatedMark::new(self.var_store),
module_params.pattern(),
))
}
None => None,
}
}
fn params_extended_home_symbol(&self, symbol: &Symbol) -> Option<(&ModuleParams, usize)> {
if symbol.module_id() == self.home_id {
match self.home_params {
@ -392,6 +421,12 @@ impl<'a> LowerParams<'a> {
}
}
fn get_imported_def_arity(&self, symbol: &Symbol) -> Option<usize> {
self.imported_params
.get(&symbol.module_id())
.and_then(|params| params.arity_by_name.get(&symbol.ident_id()).copied())
}
fn lower_naked_params_var(
&mut self,
arity: usize,
@ -400,14 +435,6 @@ impl<'a> LowerParams<'a> {
params_symbol: Symbol,
params_var: Variable,
) -> Expr {
let params_arg = (params_var, Loc::at_zero(Var(params_symbol, params_var)));
let call_fn = Box::new((
self.var_store.fresh(),
Loc::at_zero(Var(symbol, var)),
self.var_store.fresh(),
self.var_store.fresh(),
));
if arity == 0 {
// We are passing a top-level value that takes params, so we need to replace the Var
// with a call that passes the params to get the final result.
@ -416,12 +443,7 @@ impl<'a> LowerParams<'a> {
// record = \... #params -> { doubled: value }
// ↓
// value #params
Call(
call_fn,
vec![params_arg],
// todo: custom called via
roc_module::called_via::CalledVia::Space,
)
self.call_value_def_with_params(symbol, var, params_symbol, params_var)
} else {
// We are passing a top-level function that takes params, so we need to replace
// the Var with a closure that captures the params and passes them to the function.
@ -446,8 +468,17 @@ impl<'a> LowerParams<'a> {
call_arguments.push((var, Loc::at_zero(Var(sym, var))));
}
let params_arg = (params_var, Loc::at_zero(Var(params_symbol, params_var)));
call_arguments.push(params_arg);
let call_fn = Box::new((
self.var_store.fresh(),
Loc::at_zero(Var(symbol, var)),
self.var_store.fresh(),
self.var_store.fresh(),
));
let body = Call(
call_fn,
call_arguments,
@ -475,17 +506,28 @@ impl<'a> LowerParams<'a> {
})
}
}
fn home_params_argument(&mut self) -> Option<(Variable, AnnotatedMark, Loc<Pattern>)> {
match &self.home_params {
Some(module_params) => {
let new_var = self.var_store.fresh();
Some((
new_var,
AnnotatedMark::new(self.var_store),
module_params.pattern(),
))
}
None => None,
}
fn call_value_def_with_params(
&mut self,
symbol: Symbol,
var: Variable,
params_symbol: Symbol,
params_var: Variable,
) -> Expr {
let params_arg = (params_var, Loc::at_zero(Var(params_symbol, params_var)));
let call_fn = Box::new((
self.var_store.fresh(),
Loc::at_zero(Var(symbol, var)),
self.var_store.fresh(),
self.var_store.fresh(),
));
Call(
call_fn,
vec![params_arg],
// todo: custom called via
roc_module::called_via::CalledVia::Space,
)
}
}