Merge branch 'trunk' into list-eq

This commit is contained in:
Richard Feldman 2021-01-04 08:44:30 -05:00 committed by GitHub
commit fb95c72127
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 1165 additions and 487 deletions

View file

@ -18,6 +18,7 @@ use crate::llvm::refcounting::{
}; };
use bumpalo::collections::Vec; use bumpalo::collections::Vec;
use bumpalo::Bump; use bumpalo::Bump;
use either::Either;
use inkwell::basic_block::BasicBlock; use inkwell::basic_block::BasicBlock;
use inkwell::builder::Builder; use inkwell::builder::Builder;
use inkwell::context::Context; use inkwell::context::Context;
@ -40,7 +41,7 @@ use roc_collections::all::{ImMap, MutSet};
use roc_module::ident::TagName; use roc_module::ident::TagName;
use roc_module::low_level::LowLevel; use roc_module::low_level::LowLevel;
use roc_module::symbol::{Interns, ModuleId, Symbol}; use roc_module::symbol::{Interns, ModuleId, Symbol};
use roc_mono::ir::{JoinPointId, Wrapped}; use roc_mono::ir::{CallType, JoinPointId, Wrapped};
use roc_mono::layout::{Builtin, ClosureLayout, Layout, LayoutIds, MemoryMode}; use roc_mono::layout::{Builtin, ClosureLayout, Layout, LayoutIds, MemoryMode};
use target_lexicon::CallingConvention; use target_lexicon::CallingConvention;
@ -429,6 +430,14 @@ pub fn construct_optimization_passes<'a>(
fpm.add_memcpy_optimize_pass(); // this one is very important fpm.add_memcpy_optimize_pass(); // this one is very important
fpm.add_licm_pass(); fpm.add_licm_pass();
// turn invoke into call
mpm.add_prune_eh_pass();
// remove unused global values (often the `_wrapper` can be removed)
mpm.add_global_dce_pass();
mpm.add_function_inlining_pass();
} }
} }
@ -610,26 +619,81 @@ pub fn build_exp_literal<'a, 'ctx, 'env>(
} }
} }
pub fn build_exp_expr<'a, 'ctx, 'env>( pub fn build_exp_call<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>, env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>, layout_ids: &mut LayoutIds<'a>,
scope: &Scope<'a, 'ctx>, scope: &Scope<'a, 'ctx>,
parent: FunctionValue<'ctx>, parent: FunctionValue<'ctx>,
layout: &Layout<'a>, layout: &Layout<'a>,
expr: &roc_mono::ir::Expr<'a>, call: &roc_mono::ir::Call<'a>,
) -> BasicValueEnum<'ctx> { ) -> BasicValueEnum<'ctx> {
use roc_mono::ir::CallType::*; let roc_mono::ir::Call {
use roc_mono::ir::Expr::*; call_type,
arguments,
} = call;
match expr { match call_type {
Literal(literal) => build_exp_literal(env, literal), CallType::ByName {
RunLowLevel(op, symbols) => { name, full_layout, ..
run_low_level(env, layout_ids, scope, parent, layout, *op, symbols) } => {
let mut arg_tuples: Vec<BasicValueEnum> =
Vec::with_capacity_in(arguments.len(), env.arena);
for symbol in arguments.iter() {
arg_tuples.push(load_symbol(env, scope, symbol));
}
call_with_args(
env,
layout_ids,
&full_layout,
*name,
parent,
arg_tuples.into_bump_slice(),
)
} }
ForeignCall { CallType::ByPointer { name, .. } => {
let sub_expr = load_symbol(env, scope, name);
let mut arg_vals: Vec<BasicValueEnum> =
Vec::with_capacity_in(arguments.len(), env.arena);
for arg in arguments.iter() {
arg_vals.push(load_symbol(env, scope, arg));
}
let call = match sub_expr {
BasicValueEnum::PointerValue(ptr) => {
env.builder.build_call(ptr, arg_vals.as_slice(), "tmp")
}
non_ptr => {
panic!(
"Tried to call by pointer, but encountered a non-pointer: {:?}",
non_ptr
);
}
};
if env.exposed_to_host.contains(name) {
// If this is an external-facing function, use the C calling convention.
call.set_call_convention(C_CALL_CONV);
} else {
// If it's an internal-only function, use the fast calling convention.
call.set_call_convention(FAST_CALL_CONV);
}
call.try_as_basic_value()
.left()
.unwrap_or_else(|| panic!("LLVM error: Invalid call by pointer."))
}
CallType::LowLevel { op } => {
run_low_level(env, layout_ids, scope, parent, layout, *op, arguments)
}
CallType::Foreign {
foreign_symbol, foreign_symbol,
arguments,
ret_layout, ret_layout,
} => { } => {
let mut arg_vals: Vec<BasicValueEnum> = let mut arg_vals: Vec<BasicValueEnum> =
@ -696,65 +760,23 @@ pub fn build_exp_expr<'a, 'ctx, 'env>(
.unwrap_or_else(|| panic!("LLVM error: Invalid call by pointer.")) .unwrap_or_else(|| panic!("LLVM error: Invalid call by pointer."))
} }
} }
FunctionCall { }
call_type: ByName(name), }
full_layout,
args,
..
} => {
let mut arg_tuples: Vec<BasicValueEnum> = Vec::with_capacity_in(args.len(), env.arena);
for symbol in args.iter() { pub fn build_exp_expr<'a, 'ctx, 'env>(
arg_tuples.push(load_symbol(env, scope, symbol)); env: &Env<'a, 'ctx, 'env>,
} layout_ids: &mut LayoutIds<'a>,
scope: &Scope<'a, 'ctx>,
parent: FunctionValue<'ctx>,
layout: &Layout<'a>,
expr: &roc_mono::ir::Expr<'a>,
) -> BasicValueEnum<'ctx> {
use roc_mono::ir::Expr::*;
call_with_args( match expr {
env, Literal(literal) => build_exp_literal(env, literal),
layout_ids,
&full_layout,
*name,
parent,
arg_tuples.into_bump_slice(),
)
}
FunctionCall { Call(call) => build_exp_call(env, layout_ids, scope, parent, layout, call),
call_type: ByPointer(name),
args,
..
} => {
let sub_expr = load_symbol(env, scope, name);
let mut arg_vals: Vec<BasicValueEnum> = Vec::with_capacity_in(args.len(), env.arena);
for arg in args.iter() {
arg_vals.push(load_symbol(env, scope, arg));
}
let call = match sub_expr {
BasicValueEnum::PointerValue(ptr) => {
env.builder.build_call(ptr, arg_vals.as_slice(), "tmp")
}
non_ptr => {
panic!(
"Tried to call by pointer, but encountered a non-pointer: {:?}",
non_ptr
);
}
};
if env.exposed_to_host.contains(name) {
// If this is an external-facing function, use the C calling convention.
call.set_call_convention(C_CALL_CONV);
} else {
// If it's an internal-only function, use the fast calling convention.
call.set_call_convention(FAST_CALL_CONV);
}
call.try_as_basic_value()
.left()
.unwrap_or_else(|| panic!("LLVM error: Invalid call by pointer."))
}
Struct(sorted_fields) => { Struct(sorted_fields) => {
let ctx = env.context; let ctx = env.context;
@ -1284,6 +1306,92 @@ fn list_literal<'a, 'ctx, 'env>(
) )
} }
#[allow(clippy::too_many_arguments)]
fn invoke_roc_function<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
scope: &mut Scope<'a, 'ctx>,
parent: FunctionValue<'ctx>,
symbol: Symbol,
layout: Layout<'a>,
function_value: Either<FunctionValue<'ctx>, PointerValue<'ctx>>,
arguments: &[Symbol],
pass: &'a roc_mono::ir::Stmt<'a>,
fail: &'a roc_mono::ir::Stmt<'a>,
) -> BasicValueEnum<'ctx> {
let context = env.context;
let call_bt = basic_type_from_layout(env.arena, context, &layout, env.ptr_bytes);
let alloca = create_entry_block_alloca(env, parent, call_bt, symbol.ident_string(&env.interns));
let mut arg_vals: Vec<BasicValueEnum> = Vec::with_capacity_in(arguments.len(), env.arena);
for arg in arguments.iter() {
arg_vals.push(load_symbol(env, scope, arg));
}
let pass_block = context.append_basic_block(parent, "invoke_pass");
let fail_block = context.append_basic_block(parent, "invoke_fail");
let call_result = {
let call = env.builder.build_invoke(
function_value,
arg_vals.as_slice(),
pass_block,
fail_block,
"tmp",
);
match function_value {
Either::Left(function) => {
call.set_call_convention(function.get_call_conventions());
}
Either::Right(_) => {
call.set_call_convention(FAST_CALL_CONV);
}
}
call.try_as_basic_value()
.left()
.unwrap_or_else(|| panic!("LLVM error: Invalid call by pointer."))
};
{
env.builder.position_at_end(pass_block);
env.builder.build_store(alloca, call_result);
scope.insert(symbol, (layout, alloca));
build_exp_stmt(env, layout_ids, scope, parent, pass);
scope.remove(&symbol);
}
{
env.builder.position_at_end(fail_block);
let landing_pad_type = {
let exception_ptr = context.i8_type().ptr_type(AddressSpace::Generic).into();
let selector_value = context.i32_type().into();
context.struct_type(&[exception_ptr, selector_value], false)
};
env.builder
.build_catch_all_landing_pad(
&landing_pad_type,
&BasicValueEnum::IntValue(context.i8_type().const_zero()),
context.i8_type().ptr_type(AddressSpace::Generic),
"invoke_landing_pad",
)
.into_struct_value();
build_exp_stmt(env, layout_ids, scope, parent, fail);
}
call_result
}
pub fn build_exp_stmt<'a, 'ctx, 'env>( pub fn build_exp_stmt<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>, env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>, layout_ids: &mut LayoutIds<'a>,
@ -1366,6 +1474,87 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>(
value value
} }
Invoke {
symbol,
call,
layout,
pass,
fail: roc_mono::ir::Stmt::Unreachable,
} => {
// when the fail case is just Unreachable, there is no cleanup work to do
// so we can just treat this invoke as a normal call
let stmt =
roc_mono::ir::Stmt::Let(*symbol, Expr::Call(call.clone()), layout.clone(), pass);
build_exp_stmt(env, layout_ids, scope, parent, &stmt)
}
Invoke {
symbol,
call,
layout,
pass,
fail,
} => match call.call_type {
CallType::ByName {
name,
ref full_layout,
..
} => {
let function_value = function_value_by_name(env, layout_ids, full_layout, name);
invoke_roc_function(
env,
layout_ids,
scope,
parent,
*symbol,
layout.clone(),
function_value.into(),
call.arguments,
pass,
fail,
)
}
CallType::ByPointer { name, .. } => {
let sub_expr = load_symbol(env, scope, &name);
let function_ptr = match sub_expr {
BasicValueEnum::PointerValue(ptr) => ptr,
non_ptr => {
panic!(
"Tried to call by pointer, but encountered a non-pointer: {:?}",
non_ptr
);
}
};
invoke_roc_function(
env,
layout_ids,
scope,
parent,
*symbol,
layout.clone(),
function_ptr.into(),
call.arguments,
pass,
fail,
)
}
_ => {
todo!()
}
},
Unreachable => {
cxa_rethrow_exception(env);
// used in exception handling
env.builder.build_unreachable();
env.context.i64_type().const_zero().into()
}
Switch { Switch {
branches, branches,
default_branch, default_branch,
@ -2012,7 +2201,11 @@ fn make_exception_catcher<'a, 'ctx, 'env>(
) -> FunctionValue<'ctx> { ) -> FunctionValue<'ctx> {
let wrapper_function_name = format!("{}_catcher", roc_function.get_name().to_str().unwrap()); let wrapper_function_name = format!("{}_catcher", roc_function.get_name().to_str().unwrap());
make_exception_catching_wrapper(env, roc_function, &wrapper_function_name) let function_value = make_exception_catching_wrapper(env, roc_function, &wrapper_function_name);
function_value.set_linkage(Linkage::Internal);
function_value
} }
fn make_exception_catching_wrapper<'a, 'ctx, 'env>( fn make_exception_catching_wrapper<'a, 'ctx, 'env>(
@ -2524,6 +2717,29 @@ pub fn verify_fn(fn_val: FunctionValue<'_>) {
} }
} }
fn function_value_by_name<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
layout: &Layout<'a>,
symbol: Symbol,
) -> FunctionValue<'ctx> {
let fn_name = layout_ids
.get(symbol, layout)
.to_symbol_string(symbol, &env.interns);
let fn_name = fn_name.as_str();
env.module.get_function(fn_name).unwrap_or_else(|| {
if symbol.is_builtin() {
panic!("Unrecognized builtin function: {:?}", fn_name)
} else {
panic!(
"Unrecognized non-builtin function: {:?} (symbol: {:?}, layout: {:?})",
fn_name, symbol, layout
)
}
})
}
// #[allow(clippy::cognitive_complexity)] // #[allow(clippy::cognitive_complexity)]
#[inline(always)] #[inline(always)]
fn call_with_args<'a, 'ctx, 'env>( fn call_with_args<'a, 'ctx, 'env>(
@ -2534,21 +2750,7 @@ fn call_with_args<'a, 'ctx, 'env>(
_parent: FunctionValue<'ctx>, _parent: FunctionValue<'ctx>,
args: &[BasicValueEnum<'ctx>], args: &[BasicValueEnum<'ctx>],
) -> BasicValueEnum<'ctx> { ) -> BasicValueEnum<'ctx> {
let fn_name = layout_ids let fn_val = function_value_by_name(env, layout_ids, layout, symbol);
.get(symbol, layout)
.to_symbol_string(symbol, &env.interns);
let fn_name = fn_name.as_str();
let fn_val = env.module.get_function(fn_name).unwrap_or_else(|| {
if symbol.is_builtin() {
panic!("Unrecognized builtin function: {:?}", fn_name)
} else {
panic!(
"Unrecognized non-builtin function: {:?} (symbol: {:?}, layout: {:?})",
fn_name, symbol, layout
)
}
});
let call = env.builder.build_call(fn_val, args, "call"); let call = env.builder.build_call(fn_val, args, "call");
@ -3868,8 +4070,7 @@ fn cxa_throw_exception<'a, 'ctx, 'env>(env: &Env<'a, 'ctx, 'env>, info: BasicVal
call.set_call_convention(C_CALL_CONV); call.set_call_convention(C_CALL_CONV);
} }
#[allow(dead_code)] fn cxa_rethrow_exception(env: &Env<'_, '_, '_>) {
fn cxa_rethrow_exception<'a, 'ctx, 'env>(env: &Env<'a, 'ctx, 'env>) -> BasicValueEnum<'ctx> {
let name = "__cxa_rethrow"; let name = "__cxa_rethrow";
let module = env.module; let module = env.module;
@ -3888,10 +4089,10 @@ fn cxa_rethrow_exception<'a, 'ctx, 'env>(env: &Env<'a, 'ctx, 'env>) -> BasicValu
cxa_rethrow cxa_rethrow
} }
}; };
let call = env.builder.build_call(function, &[], "never_used"); let call = env.builder.build_call(function, &[], "rethrow");
call.set_call_convention(C_CALL_CONV); call.set_call_convention(C_CALL_CONV);
call.try_as_basic_value().left().unwrap() // call.try_as_basic_value().left().unwrap()
} }
fn get_foreign_symbol<'a, 'ctx, 'env>( fn get_foreign_symbol<'a, 'ctx, 'env>(

View file

@ -1739,4 +1739,18 @@ mod gen_list {
assert_evals_to!("[[1]] != [[1]]", false, bool); assert_evals_to!("[[1]] != [[1]]", false, bool);
assert_evals_to!("[[2]] != [[1]]", true, bool); assert_evals_to!("[[2]] != [[1]]", true, bool);
} }
#[should_panic(expected = r#"Roc failed with message: "integer addition overflowed!"#)]
fn cleanup_because_exception() {
assert_evals_to!(
indoc!(
r#"
x = [ 1,2 ]
5 + Num.maxInt + 3 + List.len x
"#
),
RocList::from_slice(&[false; 1]),
RocList<bool>
);
}
} }

View file

@ -82,6 +82,18 @@ where
self.free_symbols(stmt); self.free_symbols(stmt);
Ok(()) Ok(())
} }
Stmt::Invoke {
symbol,
layout,
call,
pass,
fail: _,
} => {
// for now, treat invoke as a normal call
let stmt = Stmt::Let(*symbol, Expr::Call(call.clone()), layout.clone(), pass);
self.build_stmt(&stmt)
}
x => Err(format!("the statement, {:?}, is not yet implemented", x)), x => Err(format!("the statement, {:?}, is not yet implemented", x)),
} }
} }
@ -103,26 +115,31 @@ where
} }
Ok(()) Ok(())
} }
Expr::FunctionCall { Expr::Call(roc_mono::ir::Call {
call_type: CallType::ByName(func_sym), call_type,
args, arguments,
.. }) => {
} => { match call_type {
match *func_sym { CallType::ByName { name: func_sym, .. } => {
Symbol::NUM_ABS => { match *func_sym {
// Instead of calling the function, just inline it. Symbol::NUM_ABS => {
self.build_expr(sym, &Expr::RunLowLevel(LowLevel::NumAbs, args), layout) // Instead of calling the function, just inline it.
self.build_run_low_level(sym, &LowLevel::NumAbs, arguments, layout)
}
Symbol::NUM_ADD => {
// Instead of calling the function, just inline it.
self.build_run_low_level(sym, &LowLevel::NumAdd, arguments, layout)
}
x => Err(format!("the function, {:?}, is not yet implemented", x)),
}
} }
Symbol::NUM_ADD => {
// Instead of calling the function, just inline it. CallType::LowLevel { op: lowlevel } => {
self.build_expr(sym, &Expr::RunLowLevel(LowLevel::NumAdd, args), layout) self.build_run_low_level(sym, lowlevel, arguments, layout)
} }
x => Err(format!("the function, {:?}, is not yet implemented", x)), x => Err(format!("the call type, {:?}, is not yet implemented", x)),
} }
} }
Expr::RunLowLevel(lowlevel, args) => {
self.build_run_low_level(sym, lowlevel, args, layout)
}
x => Err(format!("the expression, {:?}, is not yet implemented", x)), x => Err(format!("the expression, {:?}, is not yet implemented", x)),
} }
} }
@ -244,36 +261,9 @@ where
match expr { match expr {
Expr::Literal(_) => {} Expr::Literal(_) => {}
Expr::FunctionPointer(sym, _) => self.set_last_seen(*sym, stmt), Expr::FunctionPointer(sym, _) => self.set_last_seen(*sym, stmt),
Expr::FunctionCall {
call_type, args, .. Expr::Call(call) => self.scan_ast_call(call, stmt),
} => {
for sym in *args {
self.set_last_seen(*sym, stmt);
}
match call_type {
CallType::ByName(sym) => {
// For functions that we won't inline, we should not be a leaf function.
if !INLINED_SYMBOLS.contains(sym) {
self.set_not_leaf_function();
}
}
CallType::ByPointer(sym) => {
self.set_not_leaf_function();
self.set_last_seen(*sym, stmt);
}
}
}
Expr::RunLowLevel(_, args) => {
for sym in *args {
self.set_last_seen(*sym, stmt);
}
}
Expr::ForeignCall { arguments, .. } => {
for sym in *arguments {
self.set_last_seen(*sym, stmt);
}
self.set_not_leaf_function();
}
Expr::Tag { arguments, .. } => { Expr::Tag { arguments, .. } => {
for sym in *arguments { for sym in *arguments {
self.set_last_seen(*sym, stmt); self.set_last_seen(*sym, stmt);
@ -320,6 +310,20 @@ where
} }
self.scan_ast(following); self.scan_ast(following);
} }
Stmt::Invoke {
symbol,
layout,
call,
pass,
fail: _,
} => {
// for now, treat invoke as a normal call
let stmt = Stmt::Let(*symbol, Expr::Call(call.clone()), layout.clone(), pass);
self.scan_ast(&stmt);
}
Stmt::Switch { Stmt::Switch {
cond_symbol, cond_symbol,
branches, branches,
@ -335,6 +339,7 @@ where
Stmt::Ret(sym) => { Stmt::Ret(sym) => {
self.set_last_seen(*sym, stmt); self.set_last_seen(*sym, stmt);
} }
Stmt::Unreachable => {}
Stmt::Inc(sym, following) => { Stmt::Inc(sym, following) => {
self.set_last_seen(*sym, stmt); self.set_last_seen(*sym, stmt);
self.scan_ast(following); self.scan_ast(following);
@ -364,4 +369,30 @@ where
Stmt::RuntimeError(_) => {} Stmt::RuntimeError(_) => {}
} }
} }
fn scan_ast_call(&mut self, call: &roc_mono::ir::Call, stmt: &roc_mono::ir::Stmt<'a>) {
let roc_mono::ir::Call {
call_type,
arguments,
} = call;
for sym in *arguments {
self.set_last_seen(*sym, stmt);
}
match call_type {
CallType::ByName { name: sym, .. } => {
// For functions that we won't inline, we should not be a leaf function.
if !INLINED_SYMBOLS.contains(sym) {
self.set_not_leaf_function();
}
}
CallType::ByPointer { name: sym, .. } => {
self.set_not_leaf_function();
self.set_last_seen(*sym, stmt);
}
CallType::LowLevel { .. } => {}
CallType::Foreign { .. } => self.set_not_leaf_function(),
}
}
} }

View file

@ -1790,8 +1790,6 @@ fn update<'a>(
if state.dependencies.solved_all() && state.goal_phase == Phase::MakeSpecializations { if state.dependencies.solved_all() && state.goal_phase == Phase::MakeSpecializations {
debug_assert!(work.is_empty(), "still work remaining {:?}", &work); debug_assert!(work.is_empty(), "still work remaining {:?}", &work);
Proc::insert_refcount_operations(arena, &mut state.procedures);
// display the mono IR of the module, for debug purposes // display the mono IR of the module, for debug purposes
if roc_mono::ir::PRETTY_PRINT_IR_SYMBOLS { if roc_mono::ir::PRETTY_PRINT_IR_SYMBOLS {
let procs_string = state let procs_string = state
@ -1805,6 +1803,8 @@ fn update<'a>(
println!("{}", result); println!("{}", result);
} }
Proc::insert_refcount_operations(arena, &mut state.procedures);
msg_tx msg_tx
.send(Msg::FinishedAllSpecialization { .send(Msg::FinishedAllSpecialization {
subs, subs,

View file

@ -156,6 +156,10 @@ impl<'a> ParamMap<'a> {
Let(_, _, _, cont) => { Let(_, _, _, cont) => {
stack.push(cont); stack.push(cont);
} }
Invoke { pass, fail, .. } => {
stack.push(pass);
stack.push(fail);
}
Switch { Switch {
branches, branches,
default_branch, default_branch,
@ -166,7 +170,7 @@ impl<'a> ParamMap<'a> {
} }
Inc(_, _) | Dec(_, _) => unreachable!("these have not been introduced yet"), Inc(_, _) | Dec(_, _) => unreachable!("these have not been introduced yet"),
Ret(_) | Jump(_, _) | RuntimeError(_) => { Ret(_) | Unreachable | Jump(_, _) | RuntimeError(_) => {
// these are terminal, do nothing // these are terminal, do nothing
} }
} }
@ -295,6 +299,62 @@ impl<'a> BorrowInfState<'a> {
/// ///
/// and determines whether z and which of the symbols used in e /// and determines whether z and which of the symbols used in e
/// must be taken as owned paramters /// must be taken as owned paramters
fn collect_call(&mut self, z: Symbol, e: &crate::ir::Call<'a>) {
use crate::ir::CallType::*;
let crate::ir::Call {
call_type,
arguments,
} = e;
match call_type {
ByName {
name, arg_layouts, ..
}
| ByPointer {
name, arg_layouts, ..
} => {
// get the borrow signature of the applied function
let ps = match self.param_map.get_symbol(*name) {
Some(slice) => slice,
None => Vec::from_iter_in(
arg_layouts.iter().cloned().map(|layout| Param {
symbol: Symbol::UNDERSCORE,
borrow: false,
layout,
}),
self.arena,
)
.into_bump_slice(),
};
// the return value will be owned
self.own_var(z);
// if the function exects an owned argument (ps), the argument must be owned (args)
self.own_args_using_params(arguments, ps);
}
LowLevel { op } => {
// very unsure what demand RunLowLevel should place upon its arguments
self.own_var(z);
let ps = lowlevel_borrow_signature(self.arena, *op);
self.own_args_using_bools(arguments, ps);
}
Foreign { .. } => {
// very unsure what demand ForeignCall should place upon its arguments
self.own_var(z);
let ps = foreign_borrow_signature(self.arena, arguments.len());
self.own_args_using_bools(arguments, ps);
}
}
}
fn collect_expr(&mut self, z: Symbol, e: &Expr<'a>) { fn collect_expr(&mut self, z: Symbol, e: &Expr<'a>) {
use Expr::*; use Expr::*;
@ -334,73 +394,40 @@ impl<'a> BorrowInfState<'a> {
} }
} }
FunctionCall { Call(call) => self.collect_call(z, call),
call_type,
args,
arg_layouts,
..
} => {
// get the borrow signature of the applied function
let ps = match self.param_map.get_symbol(call_type.get_inner()) {
Some(slice) => slice,
None => Vec::from_iter_in(
arg_layouts.iter().cloned().map(|layout| Param {
symbol: Symbol::UNDERSCORE,
borrow: false,
layout,
}),
self.arena,
)
.into_bump_slice(),
};
// the return value will be owned
self.own_var(z);
// if the function exects an owned argument (ps), the argument must be owned (args)
self.own_args_using_params(args, ps);
}
RunLowLevel(op, args) => {
// very unsure what demand RunLowLevel should place upon its arguments
self.own_var(z);
let ps = lowlevel_borrow_signature(self.arena, *op);
self.own_args_using_bools(args, ps);
}
ForeignCall { arguments, .. } => {
// very unsure what demand ForeignCall should place upon its arguments
self.own_var(z);
let ps = foreign_borrow_signature(self.arena, arguments.len());
self.own_args_using_bools(arguments, ps);
}
Literal(_) | FunctionPointer(_, _) | RuntimeErrorFunction(_) => {} Literal(_) | FunctionPointer(_, _) | RuntimeErrorFunction(_) => {}
} }
} }
#[allow(clippy::many_single_char_names)]
fn preserve_tail_call(&mut self, x: Symbol, v: &Expr<'a>, b: &Stmt<'a>) { fn preserve_tail_call(&mut self, x: Symbol, v: &Expr<'a>, b: &Stmt<'a>) {
if let ( match (v, b) {
Expr::FunctionCall { (
call_type, Expr::Call(crate::ir::Call {
args: ys, call_type: crate::ir::CallType::ByName { name: g, .. },
.. arguments: ys,
}, ..
Stmt::Ret(z), }),
) = (v, b) Stmt::Ret(z),
{ )
let g = call_type.get_inner(); | (
if self.current_proc == g && x == *z { Expr::Call(crate::ir::Call {
// anonymous functions (for which the ps may not be known) call_type: crate::ir::CallType::ByPointer { name: g, .. },
// can never be tail-recursive, so this is fine arguments: ys,
if let Some(ps) = self.param_map.get_symbol(g) { ..
self.own_params_using_args(ys, ps) }),
Stmt::Ret(z),
) => {
if self.current_proc == *g && x == *z {
// anonymous functions (for which the ps may not be known)
// can never be tail-recursive, so this is fine
if let Some(ps) = self.param_map.get_symbol(*g) {
self.own_params_using_args(ys, ps)
}
} }
} }
_ => {}
} }
} }
@ -444,11 +471,29 @@ impl<'a> BorrowInfState<'a> {
self.collect_stmt(b); self.collect_stmt(b);
self.preserve_tail_call(*x, &Expr::FunctionPointer(*fsymbol, layout.clone()), b); self.preserve_tail_call(*x, &Expr::FunctionPointer(*fsymbol, layout.clone()), b);
} }
Let(x, v, _, b) => { Let(x, v, _, b) => {
self.collect_stmt(b); self.collect_stmt(b);
self.collect_expr(*x, v); self.collect_expr(*x, v);
self.preserve_tail_call(*x, v, b); self.preserve_tail_call(*x, v, b);
} }
Invoke {
symbol,
call,
layout: _,
pass,
fail,
} => {
self.collect_stmt(pass);
self.collect_stmt(fail);
self.collect_call(*symbol, call);
// TODO how to preserve the tail call of an invoke?
// self.preserve_tail_call(*x, v, b);
}
Jump(j, ys) => { Jump(j, ys) => {
let ps = self.param_map.get_join_point(*j); let ps = self.param_map.get_join_point(*j);
@ -470,7 +515,7 @@ impl<'a> BorrowInfState<'a> {
} }
Inc(_, _) | Dec(_, _) => unreachable!("these have not been introduced yet"), Inc(_, _) | Dec(_, _) => unreachable!("these have not been introduced yet"),
Ret(_) | RuntimeError(_) => { Ret(_) | RuntimeError(_) | Unreachable => {
// these are terminal, do nothing // these are terminal, do nothing
} }
} }

View file

@ -1277,7 +1277,10 @@ fn compile_test<'a>(
ret_layout, ret_layout,
); );
let test = Expr::RunLowLevel(LowLevel::Eq, arena.alloc([lhs, rhs])); let test = Expr::Call(crate::ir::Call {
call_type: crate::ir::CallType::LowLevel { op: LowLevel::Eq },
arguments: arena.alloc([lhs, rhs]),
});
// write to the test symbol // write to the test symbol
cond = Stmt::Let( cond = Stmt::Let(

View file

@ -31,10 +31,27 @@ pub fn occuring_variables(stmt: &Stmt<'_>) -> (MutSet<Symbol>, MutSet<Symbol>) {
bound_variables.insert(*symbol); bound_variables.insert(*symbol);
stack.push(cont); stack.push(cont);
} }
Invoke {
symbol,
call,
pass,
fail,
..
} => {
occuring_variables_call(call, &mut result);
result.insert(*symbol);
bound_variables.insert(*symbol);
stack.push(pass);
stack.push(fail);
}
Ret(symbol) => { Ret(symbol) => {
result.insert(*symbol); result.insert(*symbol);
} }
Unreachable => {}
Inc(symbol, cont) | Dec(symbol, cont) => { Inc(symbol, cont) | Dec(symbol, cont) => {
result.insert(*symbol); result.insert(*symbol);
stack.push(cont); stack.push(cont);
@ -75,6 +92,12 @@ pub fn occuring_variables(stmt: &Stmt<'_>) -> (MutSet<Symbol>, MutSet<Symbol>) {
(result, bound_variables) (result, bound_variables)
} }
fn occuring_variables_call(call: &crate::ir::Call<'_>, result: &mut MutSet<Symbol>) {
// NOTE though the function name does occur, it is a static constant in the program
// for liveness, it should not be included here.
result.extend(call.arguments.iter().copied());
}
pub fn occuring_variables_expr(expr: &Expr<'_>, result: &mut MutSet<Symbol>) { pub fn occuring_variables_expr(expr: &Expr<'_>, result: &mut MutSet<Symbol>) {
use Expr::*; use Expr::*;
@ -86,11 +109,7 @@ pub fn occuring_variables_expr(expr: &Expr<'_>, result: &mut MutSet<Symbol>) {
result.insert(*symbol); result.insert(*symbol);
} }
FunctionCall { args, .. } => { Call(call) => occuring_variables_call(call, result),
// NOTE thouth the function name does occur, it is a static constant in the program
// for liveness, it should not be included here.
result.extend(args.iter().copied());
}
Tag { arguments, .. } Tag { arguments, .. }
| Struct(arguments) | Struct(arguments)
@ -108,12 +127,6 @@ pub fn occuring_variables_expr(expr: &Expr<'_>, result: &mut MutSet<Symbol>) {
Reset(x) => { Reset(x) => {
result.insert(*x); result.insert(*x);
} }
RunLowLevel(_, args) => {
result.extend(args.iter());
}
ForeignCall { arguments, .. } => {
result.extend(arguments.iter());
}
EmptyArray | RuntimeErrorFunction(_) | Literal(_) => {} EmptyArray | RuntimeErrorFunction(_) | Literal(_) => {}
} }
@ -208,6 +221,11 @@ fn consume_expr(m: &VarMap, e: &Expr<'_>) -> bool {
} }
} }
fn consume_call(_: &VarMap, _: &crate::ir::Call<'_>) -> bool {
// variables bound by a call (or invoke) must always be consumed
true
}
impl<'a> Context<'a> { impl<'a> Context<'a> {
pub fn new(arena: &'a Bump, param_map: &'a ParamMap<'a>) -> Self { pub fn new(arena: &'a Bump, param_map: &'a ParamMap<'a>) -> Self {
let mut vars = MutMap::default(); let mut vars = MutMap::default();
@ -410,6 +428,75 @@ impl<'a> Context<'a> {
b b
} }
fn visit_call(
&self,
z: Symbol,
call_type: crate::ir::CallType<'a>,
arguments: &'a [Symbol],
l: Layout<'a>,
b: &'a Stmt<'a>,
b_live_vars: &LiveVarSet,
) -> &'a Stmt<'a> {
use crate::ir::CallType::*;
match &call_type {
LowLevel { op } => {
let ps = crate::borrow::lowlevel_borrow_signature(self.arena, *op);
let b = self.add_dec_after_lowlevel(arguments, ps, b, b_live_vars);
let v = Expr::Call(crate::ir::Call {
call_type,
arguments,
});
&*self.arena.alloc(Stmt::Let(z, v, l, b))
}
Foreign { .. } => {
let ps = crate::borrow::foreign_borrow_signature(self.arena, arguments.len());
let b = self.add_dec_after_lowlevel(arguments, ps, b, b_live_vars);
let v = Expr::Call(crate::ir::Call {
call_type,
arguments,
});
&*self.arena.alloc(Stmt::Let(z, v, l, b))
}
ByName {
name, arg_layouts, ..
}
| ByPointer {
name, arg_layouts, ..
} => {
// get the borrow signature
let ps = match self.param_map.get_symbol(*name) {
Some(slice) => slice,
None => Vec::from_iter_in(
arg_layouts.iter().cloned().map(|layout| Param {
symbol: Symbol::UNDERSCORE,
borrow: false,
layout,
}),
self.arena,
)
.into_bump_slice(),
};
let v = Expr::Call(crate::ir::Call {
call_type,
arguments,
});
let b = self.add_dec_after_application(arguments, ps, b, b_live_vars);
let b = self.arena.alloc(Stmt::Let(z, v, l, b));
self.add_inc_before(arguments, ps, b, b_live_vars)
}
}
}
#[allow(clippy::many_single_char_names)] #[allow(clippy::many_single_char_names)]
fn visit_variable_declaration( fn visit_variable_declaration(
&self, &self,
@ -445,45 +532,10 @@ impl<'a> Context<'a> {
self.arena.alloc(Stmt::Let(z, v, l, b)) self.arena.alloc(Stmt::Let(z, v, l, b))
} }
RunLowLevel(op, args) => { Call(crate::ir::Call {
let ps = crate::borrow::lowlevel_borrow_signature(self.arena, op);
let b = self.add_dec_after_lowlevel(args, ps, b, b_live_vars);
self.arena.alloc(Stmt::Let(z, v, l, b))
}
ForeignCall { arguments, .. } => {
let ps = crate::borrow::foreign_borrow_signature(self.arena, arguments.len());
let b = self.add_dec_after_lowlevel(arguments, ps, b, b_live_vars);
self.arena.alloc(Stmt::Let(z, v, l, b))
}
FunctionCall {
args: ys,
arg_layouts,
call_type, call_type,
.. arguments,
} => { }) => self.visit_call(z, call_type, arguments, l, b, b_live_vars),
// get the borrow signature
let ps = match self.param_map.get_symbol(call_type.get_inner()) {
Some(slice) => slice,
None => Vec::from_iter_in(
arg_layouts.iter().cloned().map(|layout| Param {
symbol: Symbol::UNDERSCORE,
borrow: false,
layout,
}),
self.arena,
)
.into_bump_slice(),
};
let b = self.add_dec_after_application(ys, ps, b, b_live_vars);
let b = self.arena.alloc(Stmt::Let(z, v, l, b));
self.add_inc_before(ys, ps, b, b_live_vars)
}
EmptyArray EmptyArray
| FunctionPointer(_, _) | FunctionPointer(_, _)
@ -499,21 +551,45 @@ impl<'a> Context<'a> {
(new_b, live_vars) (new_b, live_vars)
} }
fn update_var_info_invoke(
&self,
symbol: Symbol,
layout: &Layout<'a>,
call: &crate::ir::Call<'a>,
) -> Self {
// is this value a constant?
// TODO do function pointers also fall into this category?
let persistent = call.arguments.is_empty();
// must this value be consumed?
let consume = consume_call(&self.vars, call);
self.update_var_info_help(symbol, layout, persistent, consume)
}
fn update_var_info(&self, symbol: Symbol, layout: &Layout<'a>, expr: &Expr<'a>) -> Self { fn update_var_info(&self, symbol: Symbol, layout: &Layout<'a>, expr: &Expr<'a>) -> Self {
let mut ctx = self.clone();
// can this type be reference-counted at runtime?
let reference = layout.contains_refcounted();
// is this value a constant? // is this value a constant?
// TODO do function pointers also fall into this category? // TODO do function pointers also fall into this category?
let persistent = match expr { let persistent = match expr {
Expr::FunctionCall { args, .. } => args.is_empty(), Expr::Call(crate::ir::Call { arguments, .. }) => arguments.is_empty(),
_ => false, _ => false,
}; };
// must this value be consumed? // must this value be consumed?
let consume = consume_expr(&ctx.vars, expr); let consume = consume_expr(&self.vars, expr);
self.update_var_info_help(symbol, layout, persistent, consume)
}
fn update_var_info_help(
&self,
symbol: Symbol,
layout: &Layout<'a>,
persistent: bool,
consume: bool,
) -> Self {
// can this type be reference-counted at runtime?
let reference = layout.contains_refcounted();
let info = VarInfo { let info = VarInfo {
reference, reference,
@ -521,6 +597,8 @@ impl<'a> Context<'a> {
consume, consume,
}; };
let mut ctx = self.clone();
ctx.vars.insert(symbol, info); ctx.vars.insert(symbol, info);
ctx ctx
@ -628,6 +706,47 @@ impl<'a> Context<'a> {
) )
} }
Invoke {
symbol,
call,
pass,
fail,
layout,
} => {
// TODO this combines parts of Let and Switch. Did this happen correctly?
let mut case_live_vars = collect_stmt(stmt, &self.jp_live_vars, MutSet::default());
case_live_vars.remove(symbol);
let fail = {
// TODO should we use ctor info like Lean?
let ctx = self.clone();
let (b, alt_live_vars) = ctx.visit_stmt(fail);
ctx.add_dec_for_alt(&case_live_vars, &alt_live_vars, b)
};
case_live_vars.insert(*symbol);
let pass = {
// TODO should we use ctor info like Lean?
let ctx = self.clone();
let ctx = ctx.update_var_info_invoke(*symbol, layout, call);
let (b, alt_live_vars) = ctx.visit_stmt(pass);
ctx.add_dec_for_alt(&case_live_vars, &alt_live_vars, b)
};
let invoke = Invoke {
symbol: *symbol,
call: call.clone(),
pass,
fail,
layout: layout.clone(),
};
let stmt = self.arena.alloc(invoke);
(stmt, case_live_vars)
}
Join { Join {
id: j, id: j,
parameters: _, parameters: _,
@ -673,6 +792,8 @@ impl<'a> Context<'a> {
} }
} }
Unreachable => (stmt, MutSet::default()),
Jump(j, xs) => { Jump(j, xs) => {
let empty = MutSet::default(); let empty = MutSet::default();
let j_live_vars = match self.jp_live_vars.get(j) { let j_live_vars = match self.jp_live_vars.get(j) {
@ -757,6 +878,25 @@ pub fn collect_stmt(
vars vars
} }
Invoke {
symbol,
call,
pass,
fail,
..
} => {
vars = collect_stmt(pass, jp_live_vars, vars);
vars = collect_stmt(fail, jp_live_vars, vars);
vars.remove(symbol);
let mut result = MutSet::default();
occuring_variables_call(call, &mut result);
vars.extend(result);
vars
}
Ret(symbol) => { Ret(symbol) => {
vars.insert(*symbol); vars.insert(*symbol);
vars vars
@ -813,6 +953,8 @@ pub fn collect_stmt(
vars vars
} }
Unreachable => vars,
RuntimeError(_) => vars, RuntimeError(_) => vars,
} }
} }

View file

@ -741,6 +741,13 @@ pub type Stores<'a> = &'a [(Symbol, Layout<'a>, Expr<'a>)];
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub enum Stmt<'a> { pub enum Stmt<'a> {
Let(Symbol, Expr<'a>, Layout<'a>, &'a Stmt<'a>), Let(Symbol, Expr<'a>, Layout<'a>, &'a Stmt<'a>),
Invoke {
symbol: Symbol,
call: Call<'a>,
layout: Layout<'a>,
pass: &'a Stmt<'a>,
fail: &'a Stmt<'a>,
},
Switch { Switch {
/// This *must* stand for an integer, because Switch potentially compiles to a jump table. /// This *must* stand for an integer, because Switch potentially compiles to a jump table.
cond_symbol: Symbol, cond_symbol: Symbol,
@ -754,6 +761,7 @@ pub enum Stmt<'a> {
ret_layout: Layout<'a>, ret_layout: Layout<'a>,
}, },
Ret(Symbol), Ret(Symbol),
Unreachable,
Inc(Symbol, &'a Stmt<'a>), Inc(Symbol, &'a Stmt<'a>),
Dec(Symbol, &'a Stmt<'a>), Dec(Symbol, &'a Stmt<'a>),
Join { Join {
@ -784,20 +792,6 @@ pub enum Literal<'a> {
/// compile to bytes, e.g. [ Blue, Black, Red, Green, White ] /// compile to bytes, e.g. [ Blue, Black, Red, Green, White ]
Byte(u8), Byte(u8),
} }
#[derive(Clone, Debug, PartialEq, Copy)]
pub enum CallType {
ByName(Symbol),
ByPointer(Symbol),
}
impl CallType {
pub fn get_inner(&self) -> Symbol {
match self {
CallType::ByName(s) => *s,
CallType::ByPointer(s) => *s,
}
}
}
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, PartialEq)]
pub enum Wrapped { pub enum Wrapped {
@ -837,25 +831,97 @@ impl Wrapped {
} }
} }
#[derive(Clone, Debug, PartialEq)]
pub struct Call<'a> {
pub call_type: CallType<'a>,
pub arguments: &'a [Symbol],
}
impl<'a> Call<'a> {
pub fn to_doc<'b, D, A>(&'b self, alloc: &'b D) -> DocBuilder<'b, D, A>
where
D: DocAllocator<'b, A>,
D::Doc: Clone,
A: Clone,
{
use CallType::*;
let arguments = self.arguments;
match self.call_type {
CallType::ByName { name, .. } => {
let it = std::iter::once(name)
.chain(arguments.iter().copied())
.map(|s| symbol_to_doc(alloc, s));
alloc.text("CallByName ").append(alloc.intersperse(it, " "))
}
CallType::ByPointer { name, .. } => {
let it = std::iter::once(name)
.chain(arguments.iter().copied())
.map(|s| symbol_to_doc(alloc, s));
alloc
.text("CallByPointer ")
.append(alloc.intersperse(it, " "))
}
LowLevel { op: lowlevel } => {
let it = arguments.iter().map(|s| symbol_to_doc(alloc, *s));
alloc
.text(format!("lowlevel {:?} ", lowlevel))
.append(alloc.intersperse(it, " "))
}
Foreign {
ref foreign_symbol, ..
} => {
let it = arguments.iter().map(|s| symbol_to_doc(alloc, *s));
alloc
.text(format!("foreign {:?} ", foreign_symbol.as_str()))
.append(alloc.intersperse(it, " "))
}
}
}
}
#[derive(Clone, Debug, PartialEq)]
pub enum CallType<'a> {
ByName {
name: Symbol,
full_layout: Layout<'a>,
ret_layout: Layout<'a>,
arg_layouts: &'a [Layout<'a>],
},
ByPointer {
name: Symbol,
full_layout: Layout<'a>,
ret_layout: Layout<'a>,
arg_layouts: &'a [Layout<'a>],
},
Foreign {
foreign_symbol: ForeignSymbol,
ret_layout: Layout<'a>,
},
LowLevel {
op: LowLevel,
},
}
// x = f a b c; S
//
//
// invoke x = f a b c in S else Unreachable
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub enum Expr<'a> { pub enum Expr<'a> {
Literal(Literal<'a>), Literal(Literal<'a>),
// Functions // Functions
FunctionPointer(Symbol, Layout<'a>), FunctionPointer(Symbol, Layout<'a>),
FunctionCall { Call(Call<'a>),
call_type: CallType,
full_layout: Layout<'a>,
ret_layout: Layout<'a>,
arg_layouts: &'a [Layout<'a>],
args: &'a [Symbol],
},
RunLowLevel(LowLevel, &'a [Symbol]),
ForeignCall {
foreign_symbol: ForeignSymbol,
arguments: &'a [Symbol],
ret_layout: Layout<'a>,
},
Tag { Tag {
tag_layout: Layout<'a>, tag_layout: Layout<'a>,
@ -956,44 +1022,8 @@ impl<'a> Expr<'a> {
.text("FunctionPointer ") .text("FunctionPointer ")
.append(symbol_to_doc(alloc, *symbol)), .append(symbol_to_doc(alloc, *symbol)),
FunctionCall { Call(call) => call.to_doc(alloc),
call_type, args, ..
} => match call_type {
CallType::ByName(name) => {
let it = std::iter::once(name)
.chain(args.iter())
.map(|s| symbol_to_doc(alloc, *s));
alloc.text("CallByName ").append(alloc.intersperse(it, " "))
}
CallType::ByPointer(name) => {
let it = std::iter::once(name)
.chain(args.iter())
.map(|s| symbol_to_doc(alloc, *s));
alloc
.text("CallByPointer ")
.append(alloc.intersperse(it, " "))
}
},
RunLowLevel(lowlevel, args) => {
let it = args.iter().map(|s| symbol_to_doc(alloc, *s));
alloc
.text(format!("lowlevel {:?} ", lowlevel))
.append(alloc.intersperse(it, " "))
}
ForeignCall {
foreign_symbol,
arguments,
..
} => {
let it = arguments.iter().map(|s| symbol_to_doc(alloc, *s));
alloc
.text(format!("foreign {:?} ", foreign_symbol.as_str()))
.append(alloc.intersperse(it, " "))
}
Tag { Tag {
tag_name, tag_name,
arguments, arguments,
@ -1098,11 +1128,45 @@ impl<'a> Stmt<'a> {
.append(alloc.hardline()) .append(alloc.hardline())
.append(cont.to_doc(alloc)), .append(cont.to_doc(alloc)),
Invoke {
symbol,
call,
pass,
fail: Stmt::Unreachable,
..
} => alloc
.text("let ")
.append(symbol_to_doc(alloc, *symbol))
.append(" = ")
.append(call.to_doc(alloc))
.append(";")
.append(alloc.hardline())
.append(pass.to_doc(alloc)),
Invoke {
symbol,
call,
pass,
fail,
..
} => alloc
.text("invoke ")
.append(symbol_to_doc(alloc, *symbol))
.append(" = ")
.append(call.to_doc(alloc))
.append(" catch")
.append(alloc.hardline())
.append(fail.to_doc(alloc).indent(4))
.append(alloc.hardline())
.append(pass.to_doc(alloc)),
Ret(symbol) => alloc Ret(symbol) => alloc
.text("ret ") .text("ret ")
.append(symbol_to_doc(alloc, *symbol)) .append(symbol_to_doc(alloc, *symbol))
.append(";"), .append(";"),
Unreachable => alloc.text("unreachable;"),
Switch { Switch {
cond_symbol, cond_symbol,
branches, branches,
@ -3506,13 +3570,15 @@ pub fn with_hole<'a>(
// build the call // build the call
result = Stmt::Let( result = Stmt::Let(
assigned, assigned,
Expr::FunctionCall { Expr::Call(self::Call {
call_type: CallType::ByPointer(closure_function_symbol), call_type: CallType::ByPointer {
full_layout: function_ptr_layout.clone(), name: closure_function_symbol,
ret_layout: ret_layout.clone(), full_layout: function_ptr_layout.clone(),
args: arg_symbols, ret_layout: ret_layout.clone(),
arg_layouts, arg_layouts,
}, },
arguments: arg_symbols,
}),
ret_layout, ret_layout,
arena.alloc(hole), arena.alloc(hole),
); );
@ -3553,13 +3619,15 @@ pub fn with_hole<'a>(
} else { } else {
result = Stmt::Let( result = Stmt::Let(
assigned, assigned,
Expr::FunctionCall { Expr::Call(self::Call {
call_type: CallType::ByPointer(function_symbol), call_type: CallType::ByPointer {
full_layout, name: function_symbol,
ret_layout: ret_layout.clone(), full_layout,
args: arg_symbols, ret_layout: ret_layout.clone(),
arg_layouts, arg_layouts,
}, },
arguments: arg_symbols,
}),
ret_layout, ret_layout,
arena.alloc(hole), arena.alloc(hole),
); );
@ -3570,13 +3638,15 @@ pub fn with_hole<'a>(
result = Stmt::Let( result = Stmt::Let(
assigned, assigned,
Expr::FunctionCall { Expr::Call(self::Call {
call_type: CallType::ByPointer(function_symbol), call_type: CallType::ByPointer {
full_layout, name: function_symbol,
ret_layout: ret_layout.clone(), full_layout,
args: arg_symbols, ret_layout: ret_layout.clone(),
arg_layouts, arg_layouts,
}, },
arguments: arg_symbols,
}),
ret_layout, ret_layout,
arena.alloc(hole), arena.alloc(hole),
); );
@ -3615,16 +3685,15 @@ pub fn with_hole<'a>(
.from_var(env.arena, ret_var, env.subs) .from_var(env.arena, ret_var, env.subs)
.unwrap_or_else(|err| todo!("TODO turn fn_var into a RuntimeError {:?}", err)); .unwrap_or_else(|err| todo!("TODO turn fn_var into a RuntimeError {:?}", err));
let result = Stmt::Let( let call = self::Call {
assigned, call_type: CallType::Foreign {
Expr::ForeignCall {
foreign_symbol, foreign_symbol,
arguments: arg_symbols,
ret_layout: layout.clone(), ret_layout: layout.clone(),
}, },
layout, arguments: arg_symbols,
hole, };
);
let result = build_call(env, call, assigned, layout, hole);
let iter = args let iter = args
.into_iter() .into_iter()
@ -3663,7 +3732,12 @@ pub fn with_hole<'a>(
} }
}; };
let result = Stmt::Let(assigned, Expr::RunLowLevel(op, arg_symbols), layout, hole); let call = self::Call {
call_type: CallType::LowLevel { op },
arguments: arg_symbols,
};
let result = build_call(env, call, assigned, layout, hole);
let iter = args let iter = args
.into_iter() .into_iter()
@ -4324,6 +4398,33 @@ fn substitute_in_stmt_help<'a>(
None None
} }
} }
Invoke {
symbol,
call,
layout,
pass,
fail,
} => {
let opt_call = substitute_in_call(arena, call, subs);
let opt_pass = substitute_in_stmt_help(arena, pass, subs);
let opt_fail = substitute_in_stmt_help(arena, fail, subs);
if opt_pass.is_some() || opt_fail.is_some() | opt_call.is_some() {
let pass = opt_pass.unwrap_or(pass);
let fail = opt_fail.unwrap_or_else(|| *fail);
let call = opt_call.unwrap_or_else(|| call.clone());
Some(arena.alloc(Invoke {
symbol: *symbol,
call,
layout: layout.clone(),
pass,
fail,
}))
} else {
None
}
}
Join { Join {
id, id,
parameters, parameters,
@ -4436,10 +4537,75 @@ fn substitute_in_stmt_help<'a>(
} }
} }
Unreachable => None,
RuntimeError(_) => None, RuntimeError(_) => None,
} }
} }
fn substitute_in_call<'a>(
arena: &'a Bump,
call: &'a Call<'a>,
subs: &MutMap<Symbol, Symbol>,
) -> Option<Call<'a>> {
let Call {
call_type,
arguments,
} = call;
let opt_call_type = match call_type {
CallType::ByName {
name,
arg_layouts,
ret_layout,
full_layout,
} => substitute(subs, *name).map(|new| CallType::ByName {
name: new,
arg_layouts,
ret_layout: ret_layout.clone(),
full_layout: full_layout.clone(),
}),
CallType::ByPointer {
name,
arg_layouts,
ret_layout,
full_layout,
} => substitute(subs, *name).map(|new| CallType::ByPointer {
name: new,
arg_layouts,
ret_layout: ret_layout.clone(),
full_layout: full_layout.clone(),
}),
CallType::Foreign { .. } => None,
CallType::LowLevel { .. } => None,
};
let mut did_change = false;
let new_args = Vec::from_iter_in(
arguments.iter().map(|s| match substitute(subs, *s) {
None => *s,
Some(s) => {
did_change = true;
s
}
}),
arena,
);
if did_change || opt_call_type.is_some() {
let call_type = opt_call_type.unwrap_or_else(|| call_type.clone());
let arguments = new_args.into_bump_slice();
Some(self::Call {
call_type,
arguments,
})
} else {
None
}
}
fn substitute_in_expr<'a>( fn substitute_in_expr<'a>(
arena: &'a Bump, arena: &'a Bump,
expr: &'a Expr<'a>, expr: &'a Expr<'a>,
@ -4450,96 +4616,7 @@ fn substitute_in_expr<'a>(
match expr { match expr {
Literal(_) | FunctionPointer(_, _) | EmptyArray | RuntimeErrorFunction(_) => None, Literal(_) | FunctionPointer(_, _) | EmptyArray | RuntimeErrorFunction(_) => None,
FunctionCall { Call(call) => substitute_in_call(arena, call, subs).map(Expr::Call),
call_type,
args,
arg_layouts,
ret_layout,
full_layout,
} => {
let opt_call_type = match call_type {
CallType::ByName(s) => substitute(subs, *s).map(CallType::ByName),
CallType::ByPointer(s) => substitute(subs, *s).map(CallType::ByPointer),
};
let mut did_change = false;
let new_args = Vec::from_iter_in(
args.iter().map(|s| match substitute(subs, *s) {
None => *s,
Some(s) => {
did_change = true;
s
}
}),
arena,
);
if did_change || opt_call_type.is_some() {
let call_type = opt_call_type.unwrap_or(*call_type);
let args = new_args.into_bump_slice();
Some(FunctionCall {
call_type,
args,
arg_layouts: *arg_layouts,
ret_layout: ret_layout.clone(),
full_layout: full_layout.clone(),
})
} else {
None
}
}
RunLowLevel(op, args) => {
let mut did_change = false;
let new_args = Vec::from_iter_in(
args.iter().map(|s| match substitute(subs, *s) {
None => *s,
Some(s) => {
did_change = true;
s
}
}),
arena,
);
if did_change {
let args = new_args.into_bump_slice();
Some(RunLowLevel(*op, args))
} else {
None
}
}
ForeignCall {
foreign_symbol,
arguments,
ret_layout,
} => {
let mut did_change = false;
let new_args = Vec::from_iter_in(
arguments.iter().map(|s| match substitute(subs, *s) {
None => *s,
Some(s) => {
did_change = true;
s
}
}),
arena,
);
if did_change {
let args = new_args.into_bump_slice();
Some(ForeignCall {
foreign_symbol: foreign_symbol.clone(),
arguments: args,
ret_layout: ret_layout.clone(),
})
} else {
None
}
}
Tag { Tag {
tag_layout, tag_layout,
@ -5145,6 +5222,55 @@ fn add_needed_external<'a>(
existing.insert(name, solved_type); existing.insert(name, solved_type);
} }
fn can_throw_exception(call: &Call) -> bool {
match call.call_type {
CallType::ByName { name, .. } => matches!(
name,
Symbol::NUM_ADD
| Symbol::NUM_SUB
| Symbol::NUM_MUL
| Symbol::NUM_DIV_FLOAT
| Symbol::NUM_ABS
| Symbol::NUM_NEG
),
CallType::ByPointer { .. } => {
// we don't know what we're calling; it might throw, so better be safe than sorry
true
}
CallType::Foreign { .. } => {
// calling foreign functions is very unsafe
true
}
CallType::LowLevel { .. } => {
// lowlevel operations themselves don't throw
false
}
}
}
fn build_call<'a>(
env: &mut Env<'a, '_>,
call: Call<'a>,
assigned: Symbol,
layout: Layout<'a>,
hole: &'a Stmt<'a>,
) -> Stmt<'a> {
if can_throw_exception(&call) {
let fail = env.arena.alloc(Stmt::Unreachable);
Stmt::Invoke {
symbol: assigned,
call,
layout,
fail,
pass: hole,
}
} else {
Stmt::Let(assigned, Expr::Call(call), layout, hole)
}
}
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn call_by_name<'a>( fn call_by_name<'a>(
env: &mut Env<'a, '_>, env: &mut Env<'a, '_>,
@ -5206,15 +5332,17 @@ fn call_by_name<'a>(
"see call_by_name for background (scroll down a bit)" "see call_by_name for background (scroll down a bit)"
); );
let call = Expr::FunctionCall { let call = self::Call {
call_type: CallType::ByName(proc_name), call_type: CallType::ByName {
ret_layout: ret_layout.clone(), name: proc_name,
full_layout: full_layout.clone(), ret_layout: ret_layout.clone(),
arg_layouts, full_layout: full_layout.clone(),
args: field_symbols, arg_layouts,
},
arguments: field_symbols,
}; };
let result = Stmt::Let(assigned, call, ret_layout.clone(), hole); let result = build_call(env, call, assigned, ret_layout.clone(), hole);
let iter = loc_args.into_iter().rev().zip(field_symbols.iter().rev()); let iter = loc_args.into_iter().rev().zip(field_symbols.iter().rev());
assign_to_symbols(env, procs, layout_cache, iter, result) assign_to_symbols(env, procs, layout_cache, iter, result)
@ -5254,17 +5382,20 @@ fn call_by_name<'a>(
field_symbols.len(), field_symbols.len(),
"see call_by_name for background (scroll down a bit)" "see call_by_name for background (scroll down a bit)"
); );
let call = Expr::FunctionCall {
call_type: CallType::ByName(proc_name), let call = self::Call {
ret_layout: ret_layout.clone(), call_type: CallType::ByName {
full_layout: full_layout.clone(), name: proc_name,
arg_layouts, ret_layout: ret_layout.clone(),
args: field_symbols, full_layout: full_layout.clone(),
arg_layouts,
},
arguments: field_symbols,
}; };
let iter = loc_args.into_iter().rev().zip(field_symbols.iter().rev()); let result = build_call(env, call, assigned, ret_layout.clone(), hole);
let result = Stmt::Let(assigned, call, ret_layout.clone(), hole); let iter = loc_args.into_iter().rev().zip(field_symbols.iter().rev());
assign_to_symbols(env, procs, layout_cache, iter, result) assign_to_symbols(env, procs, layout_cache, iter, result)
} }
None => { None => {
@ -5314,12 +5445,18 @@ fn call_by_name<'a>(
// and we have to fix it here. // and we have to fix it here.
match full_layout { match full_layout {
Layout::Closure(_, closure_layout, _) => { Layout::Closure(_, closure_layout, _) => {
let call = Expr::FunctionCall { let call = self::Call {
call_type: CallType::ByName(proc_name), call_type: CallType::ByName {
ret_layout: function_layout.result.clone(), name: proc_name,
full_layout: function_layout.full.clone(), ret_layout: function_layout
arg_layouts: function_layout.arguments, .result
args: field_symbols, .clone(),
full_layout: function_layout
.full
.clone(),
arg_layouts: function_layout.arguments,
},
arguments: field_symbols,
}; };
// in the case of a closure specifically, we // in the case of a closure specifically, we
@ -5333,25 +5470,33 @@ fn call_by_name<'a>(
]), ]),
); );
Stmt::Let( build_call(
assigned, env,
call, call,
assigned,
closure_struct_layout, closure_struct_layout,
hole, hole,
) )
} }
_ => { _ => {
let call = Expr::FunctionCall { let call = self::Call {
call_type: CallType::ByName(proc_name), call_type: CallType::ByName {
ret_layout: function_layout.result.clone(), name: proc_name,
full_layout: function_layout.full.clone(), ret_layout: function_layout
arg_layouts: function_layout.arguments, .result
args: field_symbols, .clone(),
full_layout: function_layout
.full
.clone(),
arg_layouts: function_layout.arguments,
},
arguments: field_symbols,
}; };
Stmt::Let( build_call(
assigned, env,
call, call,
assigned,
function_layout.full, function_layout.full,
hole, hole,
) )
@ -5363,12 +5508,14 @@ fn call_by_name<'a>(
field_symbols.len(), field_symbols.len(),
"scroll up a bit for background" "scroll up a bit for background"
); );
let call = Expr::FunctionCall { let call = self::Call {
call_type: CallType::ByName(proc_name), call_type: CallType::ByName {
ret_layout: function_layout.result.clone(), name: proc_name,
full_layout: function_layout.full, ret_layout: function_layout.result.clone(),
arg_layouts: function_layout.arguments, full_layout: function_layout.full.clone(),
args: field_symbols, arg_layouts: function_layout.arguments,
},
arguments: field_symbols,
}; };
let iter = loc_args let iter = loc_args
@ -5376,9 +5523,10 @@ fn call_by_name<'a>(
.rev() .rev()
.zip(field_symbols.iter().rev()); .zip(field_symbols.iter().rev());
let result = Stmt::Let( let result = build_call(
assigned, env,
call, call,
assigned,
function_layout.result, function_layout.result,
hole, hole,
); );
@ -5415,18 +5563,22 @@ fn call_by_name<'a>(
"scroll up a bit for background" "scroll up a bit for background"
); );
let call = Expr::FunctionCall { let call = self::Call {
call_type: CallType::ByName(proc_name), call_type: CallType::ByName {
ret_layout: ret_layout.clone(), name: proc_name,
full_layout: full_layout.clone(), ret_layout: ret_layout.clone(),
arg_layouts, full_layout: full_layout.clone(),
args: field_symbols, arg_layouts,
},
arguments: field_symbols,
}; };
let result =
build_call(env, call, assigned, ret_layout.clone(), hole);
let iter = let iter =
loc_args.into_iter().rev().zip(field_symbols.iter().rev()); loc_args.into_iter().rev().zip(field_symbols.iter().rev());
let result = Stmt::Let(assigned, call, ret_layout.clone(), hole);
assign_to_symbols(env, procs, layout_cache, iter, result) assign_to_symbols(env, procs, layout_cache, iter, result)
} }

View file

@ -75,17 +75,38 @@ fn insert_jumps<'a>(
match stmt { match stmt {
Let( Let(
symbol, symbol,
Expr::FunctionCall { Expr::Call(crate::ir::Call {
call_type: CallType::ByName(fsym), call_type: CallType::ByName { name: fsym, .. },
args, arguments,
.. ..
}, }),
_, _,
Stmt::Ret(rsym), Stmt::Ret(rsym),
) if needle == *fsym && symbol == rsym => { ) if needle == *fsym && symbol == rsym => {
// replace the call and return with a jump // replace the call and return with a jump
let jump = Stmt::Jump(goal_id, args); let jump = Stmt::Jump(goal_id, arguments);
Some(arena.alloc(jump))
}
Invoke {
symbol,
call:
crate::ir::Call {
call_type: CallType::ByName { name: fsym, .. },
arguments,
..
},
fail,
pass: Stmt::Ret(rsym),
..
} if needle == *fsym && symbol == rsym => {
debug_assert_eq!(fail, &&Stmt::Unreachable);
// replace the call and return with a jump
let jump = Stmt::Jump(goal_id, arguments);
Some(arena.alloc(jump)) Some(arena.alloc(jump))
} }
@ -101,6 +122,35 @@ fn insert_jumps<'a>(
None None
} }
} }
Invoke {
symbol,
call,
fail,
pass,
layout,
} => {
let opt_pass = insert_jumps(arena, pass, goal_id, needle);
let opt_fail = insert_jumps(arena, fail, goal_id, needle);
if opt_pass.is_some() || opt_fail.is_some() {
let pass = opt_pass.unwrap_or(pass);
let fail = opt_fail.unwrap_or(fail);
let stmt = Invoke {
symbol: *symbol,
call: call.clone(),
layout: layout.clone(),
pass,
fail,
};
Some(arena.alloc(stmt))
} else {
None
}
}
Join { Join {
id, id,
parameters, parameters,
@ -187,6 +237,7 @@ fn insert_jumps<'a>(
None => None, None => None,
}, },
Unreachable => None,
Ret(_) => None, Ret(_) => None,
Jump(_, _) => None, Jump(_, _) => None,
RuntimeError(_) => None, RuntimeError(_) => None,

View file

@ -161,6 +161,45 @@ mod test_mono {
) )
} }
#[test]
fn ir_int_add() {
compiles_to_ir(
r#"
x = [ 1,2 ]
5 + 4 + 3 + List.len x
"#,
indoc!(
r#"
procedure List.7 (#Attr.2):
let Test.6 = lowlevel ListLen #Attr.2;
ret Test.6;
procedure Num.24 (#Attr.2, #Attr.3):
let Test.5 = lowlevel NumAdd #Attr.2 #Attr.3;
ret Test.5;
procedure Test.0 ():
let Test.11 = 1i64;
let Test.12 = 2i64;
let Test.1 = Array [Test.11, Test.12];
let Test.9 = 5i64;
let Test.10 = 4i64;
invoke Test.7 = CallByName Num.24 Test.9 Test.10 catch
dec Test.1;
unreachable;
let Test.8 = 3i64;
invoke Test.3 = CallByName Num.24 Test.7 Test.8 catch
dec Test.1;
unreachable;
let Test.4 = CallByName List.7 Test.1;
dec Test.1;
let Test.2 = CallByName Num.24 Test.3 Test.4;
ret Test.2;
"#
),
)
}
#[test] #[test]
fn ir_assignment() { fn ir_assignment() {
compiles_to_ir( compiles_to_ir(