diff --git a/compiler/builtins/src/std.rs b/compiler/builtins/src/std.rs index a8d7a27913..68e5de7a47 100644 --- a/compiler/builtins/src/std.rs +++ b/compiler/builtins/src/std.rs @@ -722,6 +722,15 @@ pub fn types() -> MutMap { ), ); + // product : List (Num a) -> Num a + add_type( + Symbol::LIST_PRODUCT, + top_level_function( + vec![list_type(num_type(flex(TVAR1)))], + Box::new(num_type(flex(TVAR1))), + ), + ); + // walk : List elem, (elem -> accum -> accum), accum -> accum add_type( Symbol::LIST_WALK, diff --git a/compiler/can/src/builtins.rs b/compiler/can/src/builtins.rs index fee0598286..f393b276b6 100644 --- a/compiler/can/src/builtins.rs +++ b/compiler/can/src/builtins.rs @@ -77,6 +77,7 @@ pub fn builtin_defs_map(symbol: Symbol, var_store: &mut VarStore) -> Option LIST_CONCAT => list_concat, LIST_CONTAINS => list_contains, LIST_SUM => list_sum, + LIST_PRODUCT => list_product, LIST_PREPEND => list_prepend, LIST_JOIN => list_join, LIST_MAP => list_map, @@ -217,6 +218,7 @@ pub fn builtin_defs(var_store: &mut VarStore) -> MutMap { Symbol::LIST_CONCAT => list_concat, Symbol::LIST_CONTAINS => list_contains, Symbol::LIST_SUM => list_sum, + Symbol::LIST_PRODUCT => list_product, Symbol::LIST_PREPEND => list_prepend, Symbol::LIST_JOIN => list_join, Symbol::LIST_MAP => list_map, @@ -2116,22 +2118,12 @@ fn list_walk_backwards(symbol: Symbol, var_store: &mut VarStore) -> Def { /// List.sum : List (Num a) -> Num a fn list_sum(symbol: Symbol, var_store: &mut VarStore) -> Def { - let list_var = var_store.fresh(); - let result_var = var_store.fresh(); + lowlevel_1(symbol, LowLevel::ListSum, var_store) +} - let body = RunLowLevel { - op: LowLevel::ListSum, - args: vec![(list_var, Var(Symbol::ARG_1))], - ret_var: result_var, - }; - - defn( - symbol, - vec![(list_var, Symbol::ARG_1)], - var_store, - body, - result_var, - ) +/// List.product : List (Num a) -> Num a +fn list_product(symbol: Symbol, var_store: &mut VarStore) -> Def { + lowlevel_1(symbol, LowLevel::ListProduct, var_store) } /// List.keepIf : List elem, (elem -> Bool) -> List elem diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index 6494512630..494e5f072d 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -7,8 +7,8 @@ use crate::llvm::build_hash::generic_hash; use crate::llvm::build_list::{ allocate_list, empty_list, empty_polymorphic_list, list_append, list_concat, list_contains, list_get_unsafe, list_join, list_keep_errs, list_keep_if, list_keep_oks, list_len, list_map, - list_map2, list_map3, list_map_with_index, list_prepend, list_repeat, list_reverse, list_set, - list_single, list_sum, list_walk, list_walk_backwards, + list_map2, list_map3, list_map_with_index, list_prepend, list_product, list_repeat, + list_reverse, list_set, list_single, list_sum, list_walk, list_walk_backwards, }; use crate::llvm::build_str::{ str_concat, str_count_graphemes, str_ends_with, str_from_float, str_from_int, str_from_utf8, @@ -3930,6 +3930,13 @@ fn run_low_level<'a, 'ctx, 'env>( list_sum(env, parent, list, layout) } + ListProduct => { + debug_assert_eq!(args.len(), 1); + + let list = load_symbol(scope, &args[0]); + + list_product(env, parent, list, layout) + } ListAppend => { // List.append : List elem, elem -> List elem debug_assert_eq!(args.len(), 2); diff --git a/compiler/gen/src/llvm/build_list.rs b/compiler/gen/src/llvm/build_list.rs index 6842b81dfe..40321129db 100644 --- a/compiler/gen/src/llvm/build_list.rs +++ b/compiler/gen/src/llvm/build_list.rs @@ -788,6 +788,81 @@ pub fn list_sum<'a, 'ctx, 'env>( builder.build_load(accum_alloca, "load_final_acum") } +/// List.product : List (Num a) -> Num a +pub fn list_product<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + parent: FunctionValue<'ctx>, + list: BasicValueEnum<'ctx>, + default_layout: &Layout<'a>, +) -> BasicValueEnum<'ctx> { + let ctx = env.context; + let builder = env.builder; + + let list_wrapper = list.into_struct_value(); + let len = list_len(env.builder, list_wrapper); + + let accum_type = basic_type_from_layout(env.arena, ctx, default_layout, env.ptr_bytes); + let accum_alloca = builder.build_alloca(accum_type, "alloca_walk_right_accum"); + + let default: BasicValueEnum = match accum_type { + BasicTypeEnum::IntType(int_type) => int_type.const_int(1, false).into(), + BasicTypeEnum::FloatType(float_type) => float_type.const_float(1.0).into(), + _ => unreachable!(""), + }; + + builder.build_store(accum_alloca, default); + + let then_block = ctx.append_basic_block(parent, "then"); + let cont_block = ctx.append_basic_block(parent, "branchcont"); + + let condition = builder.build_int_compare( + IntPredicate::UGT, + len, + ctx.i64_type().const_zero(), + "list_non_empty", + ); + + builder.build_conditional_branch(condition, then_block, cont_block); + + builder.position_at_end(then_block); + + let elem_ptr_type = get_ptr_type(&accum_type, AddressSpace::Generic); + let list_ptr = load_list_ptr(builder, list_wrapper, elem_ptr_type); + + let walk_right_loop = |_, elem: BasicValueEnum<'ctx>| { + // load current accumulator + let current = builder.build_load(accum_alloca, "retrieve_accum"); + + let new_current = build_num_binop( + env, + parent, + current, + default_layout, + elem, + default_layout, + roc_module::low_level::LowLevel::NumMul, + ); + + builder.build_store(accum_alloca, new_current); + }; + + incrementing_elem_loop( + builder, + ctx, + parent, + list_ptr, + len, + "#index", + walk_right_loop, + ); + + builder.build_unconditional_branch(cont_block); + + builder.position_at_end(cont_block); + + builder.build_load(accum_alloca, "load_final_acum") +} + /// List.walk : List elem, (elem -> accum -> accum), accum -> accum pub fn list_walk<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, diff --git a/compiler/module/src/low_level.rs b/compiler/module/src/low_level.rs index a84ce73ea3..876ad9edda 100644 --- a/compiler/module/src/low_level.rs +++ b/compiler/module/src/low_level.rs @@ -34,6 +34,7 @@ pub enum LowLevel { ListWalk, ListWalkBackwards, ListSum, + ListProduct, ListKeepOks, ListKeepErrs, DictSize, diff --git a/compiler/module/src/symbol.rs b/compiler/module/src/symbol.rs index 3320c3b40b..b270a405d6 100644 --- a/compiler/module/src/symbol.rs +++ b/compiler/module/src/symbol.rs @@ -913,6 +913,7 @@ define_builtins! { 23 LIST_MAP_WITH_INDEX: "mapWithIndex" 24 LIST_MAP2: "map2" 25 LIST_MAP3: "map3" + 26 LIST_PRODUCT: "product" } 5 RESULT: "Result" => { 0 RESULT_RESULT: "Result" imported // the Result.Result type alias diff --git a/compiler/mono/src/borrow.rs b/compiler/mono/src/borrow.rs index b78f7a0131..4f44fb49c2 100644 --- a/compiler/mono/src/borrow.rs +++ b/compiler/mono/src/borrow.rs @@ -657,7 +657,7 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] { ListContains => arena.alloc_slice_copy(&[borrowed, irrelevant]), ListWalk => arena.alloc_slice_copy(&[owned, irrelevant, owned]), ListWalkBackwards => arena.alloc_slice_copy(&[owned, irrelevant, owned]), - ListSum => arena.alloc_slice_copy(&[borrowed]), + ListSum | ListProduct => arena.alloc_slice_copy(&[borrowed]), // TODO when we have lists with capacity (if ever) // List.append should own its first argument diff --git a/compiler/test_gen/src/gen_list.rs b/compiler/test_gen/src/gen_list.rs index 2ed0c4fbe4..5c256299ba 100644 --- a/compiler/test_gen/src/gen_list.rs +++ b/compiler/test_gen/src/gen_list.rs @@ -1760,6 +1760,13 @@ fn list_sum() { assert_evals_to!("List.sum [ 1.1, 2.2, 3.3 ]", 6.6, f64); } +#[test] +fn list_product() { + assert_evals_to!("List.product []", 1, i64); + assert_evals_to!("List.product [ 1, 2, 3 ]", 6, i64); + assert_evals_to!("List.product [ 1.1, 2.2, 3.3 ]", 1.1 * 2.2 * 3.3, f64); +} + #[test] fn list_keep_oks() { assert_evals_to!("List.keepOks [] (\\x -> x)", 0, i64);