From 632d4eca9214f1f9760bc4f3779911dfdcad91cd Mon Sep 17 00:00:00 2001 From: Richard Feldman Date: Mon, 20 Apr 2020 23:02:14 -0400 Subject: [PATCH] Thread ret_layout through CallByName --- compiler/gen/src/llvm/build.rs | 88 +++++++++++++++++++++++++++++- compiler/mono/src/decision_tree.rs | 9 ++- compiler/mono/src/expr.rs | 16 ++++-- compiler/mono/tests/test_mono.rs | 16 +++++- compiler/mono/tests/test_opt.rs | 5 +- 5 files changed, 121 insertions(+), 13 deletions(-) diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index d4c32193f6..ff10859ab5 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -188,7 +188,7 @@ pub fn build_expr<'a, 'ctx, 'env>( build_expr(env, &scope, parent, ret, procs) } - CallByName(symbol, args) => match *symbol { + CallByName(symbol, args, ret_layout) => match *symbol { Symbol::BOOL_OR => { // The (||) operator debug_assert!(args.len() == 2); @@ -236,7 +236,13 @@ pub fn build_expr<'a, 'ctx, 'env>( arg_tuples.push((build_expr(env, scope, parent, arg, procs), layout)); } - call_with_args(*symbol, parent, arg_tuples.into_bump_slice(), env) + call_with_args( + *symbol, + parent, + arg_tuples.into_bump_slice(), + env, + ret_layout, + ) } }, FunctionPointer(symbol) => { @@ -873,6 +879,7 @@ fn call_with_args<'a, 'ctx, 'env>( parent: FunctionValue<'ctx>, args: &[(BasicValueEnum<'ctx>, &'a Layout<'a>)], env: &Env<'a, 'ctx, 'env>, + ret_layout: &'a Layout<'a>, ) -> BasicValueEnum<'ctx> { match symbol { Symbol::INT_ADD | Symbol::NUM_ADD => { @@ -1072,6 +1079,7 @@ fn call_with_args<'a, 'ctx, 'env>( } } } + Symbol::LIST_GET => list_get(parent, args, env, ret_layout), Symbol::LIST_SET => list_set(parent, args, env, InPlace::Clone), Symbol::LIST_SET_IN_PLACE => list_set(parent, args, env, InPlace::InPlace), _ => { @@ -1195,6 +1203,82 @@ fn bounds_check_comparison<'ctx>( builder.build_int_compare(IntPredicate::ULT, elem_index, len, "bounds_check") } +fn list_get<'a, 'ctx, 'env>( + parent: FunctionValue<'ctx>, + args: &[(BasicValueEnum<'ctx>, &'a Layout<'a>)], + env: &Env<'a, 'ctx, 'env>, + ret_layout: &'a Layout<'a>, +) -> BasicValueEnum<'ctx> { + // List.get : List elem, Int -> [Ok elem, OutOfBounds]* + debug_assert!(args.len() == 2); + + let builder = env.builder; + let original_wrapper = args[0].0.into_struct_value(); + let list_layout = args[0].1; + let elem_index = args[1].0.into_int_value(); + + // Load the usize length from the wrapper. We need it for bounds checking. + let list_len = load_list_len(builder, original_wrapper); + + // Bounds check: only proceed if index < length. + // Otherwise, return the list unaltered. + let comparison = bounds_check_comparison(builder, elem_index, list_len); + + // If the index is in bounds, wrap the result in Ok + let build_then = || { + match list_layout { + Layout::Builtin(Builtin::List(_)) => { + // Load the pointer to the array data + let array_data_ptr = load_list_ptr(builder, original_wrapper); + + // We already checked the bounds earlier. + let elem_ptr = + unsafe { builder.build_in_bounds_gep(array_data_ptr, &[elem_index], "elem") }; + + match ret_layout { + Layout::Union(tags) => { + // TODO wrap this in an Ok. + builder.build_load(elem_ptr, "List.get") + } + _ => { + unreachable!( + "List.get did not return a tag union somehow {:?}", + list_layout + ); + } + } + } + _ => { + unreachable!("Invalid List layout for List.get: {:?}", list_layout); + } + } + }; + + // If the index was out of bounds, return OutOfBounds + let build_else = || { + // let layout = Layout::Union(layouts.into_bump_slice()); + + // Expr::Tag { + // tag_layout: layout, + // tag_name, + // tag_id: tag_id as u8, + // union_size, + // arguments: arguments.into_bump_slice(), + // } + BasicValueEnum::StructValue(original_wrapper) + }; + let ret_type = original_wrapper.get_type(); + + build_basic_phi2( + env, + parent, + comparison, + build_then, + build_else, + ret_type.into(), + ) +} + fn list_set<'a, 'ctx, 'env>( parent: FunctionValue<'ctx>, args: &[(BasicValueEnum<'ctx>, &'a Layout<'a>)], diff --git a/compiler/mono/src/decision_tree.rs b/compiler/mono/src/decision_tree.rs index af00476c12..5c9e5d0992 100644 --- a/compiler/mono/src/decision_tree.rs +++ b/compiler/mono/src/decision_tree.rs @@ -1119,7 +1119,7 @@ fn decide_to_branching<'a>( let fail = (fail_stores, &*env.arena.alloc(fail_expr)); let pass = (pass_stores, &*env.arena.alloc(pass_expr)); - let condition = boolean_all(env.arena, tests); + let condition = boolean_all(env.arena, tests, &ret_layout); let branch_symbol = env.fresh_symbol(); let stores = [(branch_symbol, Layout::Builtin(Builtin::Bool), condition)]; @@ -1199,7 +1199,11 @@ fn decide_to_branching<'a>( } } -fn boolean_all<'a>(arena: &'a Bump, tests: Vec<(Expr<'a>, Expr<'a>, Layout<'a>)>) -> Expr<'a> { +fn boolean_all<'a>( + arena: &'a Bump, + tests: Vec<(Expr<'a>, Expr<'a>, Layout<'a>)>, + ret_layout: &Layout<'a>, +) -> Expr<'a> { let mut expr = Expr::Bool(true); for (lhs, rhs, layout) in tests.into_iter().rev() { @@ -1210,6 +1214,7 @@ fn boolean_all<'a>(arena: &'a Bump, tests: Vec<(Expr<'a>, Expr<'a>, Layout<'a>)> (test, Layout::Builtin(Builtin::Bool)), (expr, Layout::Builtin(Builtin::Bool)), ]), + ret_layout.clone(), ); } diff --git a/compiler/mono/src/expr.rs b/compiler/mono/src/expr.rs index f59d67d01b..04969bed75 100644 --- a/compiler/mono/src/expr.rs +++ b/compiler/mono/src/expr.rs @@ -138,7 +138,7 @@ pub enum Expr<'a> { // Functions FunctionPointer(Symbol), - CallByName(Symbol, &'a [(Expr<'a>, Layout<'a>)]), + CallByName(Symbol, &'a [(Expr<'a>, Layout<'a>)], Layout<'a>), CallByPointer(&'a Expr<'a>, &'a [Expr<'a>], Layout<'a>), // Exactly two conditional branches, e.g. if/else @@ -1303,7 +1303,7 @@ fn call_by_name<'a>( Some(specialization) => { opt_specialize_body = None; - // a specialization with this type hash already exists, use its symbol + // a specialization with this type hash already exists, so use its symbol specialization.0 } None => { @@ -1354,13 +1354,18 @@ fn call_by_name<'a>( let mut args = Vec::with_capacity_in(loc_args.len(), env.arena); for (var, loc_arg) in loc_args { - let layout = Layout::from_var(&env.arena, var, &env.subs, env.pointer_size) - .unwrap_or_else(|err| panic!("TODO gracefully handle bad layout: {:?}", err)); + let layout = + Layout::from_var(&env.arena, var, &env.subs, env.pointer_size).unwrap_or_else(|err| { + panic!("TODO gracefully handle bad function arg layout: {:?}", err) + }); args.push((from_can(env, loc_arg.value, procs, None), layout)); } - Expr::CallByName(specialized_proc_name, args.into_bump_slice()) + let ret_layout = Layout::from_var(&env.arena, ret_var, &env.subs, env.pointer_size) + .unwrap_or_else(|err| panic!("TODO gracefully handle bad function ret layout: {:?}", err)); + + Expr::CallByName(specialized_proc_name, args.into_bump_slice(), ret_layout) } #[allow(clippy::too_many_arguments)] @@ -1704,5 +1709,6 @@ pub fn specialize_equality<'a>( Expr::CallByName( symbol, arena.alloc([(lhs, layout.clone()), (rhs, layout.clone())]), + Layout::Builtin(Builtin::Bool), ) } diff --git a/compiler/mono/tests/test_mono.rs b/compiler/mono/tests/test_mono.rs index 8269fb347c..f9b7d19254 100644 --- a/compiler/mono/tests/test_mono.rs +++ b/compiler/mono/tests/test_mono.rs @@ -90,6 +90,7 @@ mod test_mono { (Float(3.0), Layout::Builtin(Builtin::Float64)), (Float(4.0), Layout::Builtin(Builtin::Float64)), ], + Layout::Builtin(Builtin::Float64), ), ); } @@ -104,6 +105,7 @@ mod test_mono { (Int(3735928559), Layout::Builtin(Builtin::Int64)), (Int(4), Layout::Builtin(Builtin::Int64)), ], + Layout::Builtin(Builtin::Int64), ), ); } @@ -119,6 +121,7 @@ mod test_mono { (Int(3), Layout::Builtin(Builtin::Int64)), (Int(5), Layout::Builtin(Builtin::Int64)), ], + Layout::Builtin(Builtin::Int64), ), ); } @@ -141,11 +144,15 @@ mod test_mono { Struct(&[ ( - CallByName(gen_symbol_3, &[(Int(4), Builtin(Int64))]), + CallByName(gen_symbol_3, &[(Int(4), Builtin(Int64))], Builtin(Int64)), Builtin(Int64), ), ( - CallByName(gen_symbol_4, &[(Float(3.14), Builtin(Float64))]), + CallByName( + gen_symbol_4, + &[(Float(3.14), Builtin(Float64))], + Builtin(Float64), + ), Builtin(Float64), ), ]) @@ -324,11 +331,12 @@ mod test_mono { gen_symbol_3, &[( Struct(&[( - CallByName(gen_symbol_4, &[(Int(4), Builtin(Int64))]), + CallByName(gen_symbol_4, &[(Int(4), Builtin(Int64))], Builtin(Int64)), Builtin(Int64), )]), Layout::Struct(&[Builtin(Int64)]), )], + Layout::Struct(&[Builtin(Int64)]), ) }, ) @@ -483,11 +491,13 @@ mod test_mono { (Int(1), Layout::Builtin(Builtin::Int64)), (Int(42), Layout::Builtin(Builtin::Int64)), ], + Layout::Builtin(Builtin::List(&Layout::Builtin(Builtin::Int64))), ), Layout::Builtin(Builtin::List(&Layout::Builtin(Builtin::Int64))), ), (Int(1), Layout::Builtin(Builtin::Int64)), ], + Layout::Builtin(Builtin::Int64), ) }); } diff --git a/compiler/mono/tests/test_opt.rs b/compiler/mono/tests/test_opt.rs index 122a9d57e2..550083d3e1 100644 --- a/compiler/mono/tests/test_opt.rs +++ b/compiler/mono/tests/test_opt.rs @@ -73,6 +73,7 @@ mod test_opt { unexpected_calls } + fn extract_named_calls_help( expr: &Expr<'_>, calls: &mut Vec, @@ -98,7 +99,7 @@ mod test_opt { } } - CallByName(symbol, args) => { + CallByName(symbol, args, _ret_layout) => { // Search for the symbol. If we found it, check it off the list. // If we didn't find it, add it to the list of unexpected calls. match calls.binary_search(symbol) { @@ -241,11 +242,13 @@ mod test_opt { (Int(1), Layout::Builtin(Builtin::Int64)), (Int(42), Layout::Builtin(Builtin::Int64)), ], + Layout::Builtin(Builtin::List(&Layout::Builtin(Builtin::Int64))), ), Layout::Builtin(Builtin::List(&Layout::Builtin(Builtin::Int64))), ), (Int(1), Layout::Builtin(Builtin::Int64)), ], + Layout::Builtin(Builtin::Int64), ), ); }