diff --git a/compiler/builtins/src/unique.rs b/compiler/builtins/src/unique.rs index 61110152c7..bdf5cf22e1 100644 --- a/compiler/builtins/src/unique.rs +++ b/compiler/builtins/src/unique.rs @@ -787,14 +787,14 @@ pub fn types() -> MutMap { // , Attr Shared (a -> Attr * Bool) // -> Attr * (List b) add_type(Symbol::LIST_KEEP_IF, { - let_tvars! { a, b, star1, star2, star3 }; + let_tvars! { a, star1, star2, star3 }; unique_function( vec![ list_type(star1, a), shared(SolvedType::Func(vec![flex(a)], Box::new(bool_type(star2)))), ], - list_type(star3, b), + list_type(star3, a), ) }); diff --git a/compiler/gen/src/llvm/build_list.rs b/compiler/gen/src/llvm/build_list.rs index 05698953e0..b62853c654 100644 --- a/compiler/gen/src/llvm/build_list.rs +++ b/compiler/gen/src/llvm/build_list.rs @@ -285,6 +285,7 @@ pub fn list_join<'a, 'ctx, 'env>( // inner_list_len > 0 let inner_list_comparison = list_is_not_empty(builder, ctx, inner_list_len); + let inner_list_non_empty_block = ctx.append_basic_block(parent, "inner_list_non_empty"); let after_inner_list_non_empty_block = @@ -649,7 +650,144 @@ pub fn list_keep_if<'a, 'ctx, 'env>( list: BasicValueEnum<'ctx>, list_layout: &Layout<'a>, ) -> BasicValueEnum<'ctx> { - empty_list(env) + match (func, func_layout) { + ( + BasicValueEnum::PointerValue(func_ptr), + Layout::FunctionPointer(_, Layout::Builtin(Builtin::Int1)), + ) => { + let non_empty_fn = |elem_layout: &Layout<'a>, + len: IntValue<'ctx>, + list_wrapper: StructValue<'ctx>| { + let ctx = env.context; + let builder = env.builder; + + let elem_type = basic_type_from_layout(env.arena, ctx, elem_layout, env.ptr_bytes); + let elem_ptr_type = get_ptr_type(&elem_type, AddressSpace::Generic); + + let list_ptr = load_list_ptr(builder, list_wrapper, elem_ptr_type); + + let ret_list_len_name = "#ret_list_alloca"; + let ret_list_len_alloca = builder.build_alloca(ctx.i64_type(), ret_list_len_name); + builder.build_store(ret_list_len_alloca, ctx.i64_type().const_int(0, false)); + + // Return List Length Loop + let ret_list_len_loop = |_, elem: BasicValueEnum<'ctx>| { + let call_site_value = + builder.build_call(func_ptr, env.arena.alloc([elem]), "map_func"); + + // set the calling convention explicitly for this call + call_site_value.set_call_convention(crate::llvm::build::FAST_CALL_CONV); + + let should_keep = call_site_value + .try_as_basic_value() + .left() + .unwrap_or_else(|| panic!("LLVM error: Invalid call by pointer.")) + .into_int_value(); + + let loop_bb = ctx.append_basic_block(parent, "loop"); + let after_bb = ctx.append_basic_block(parent, "after_loop"); + + builder.build_conditional_branch(should_keep, loop_bb, after_bb); + builder.position_at_end(loop_bb); + + { + let next_ret_list_len = builder.build_int_add( + builder + .build_load(ret_list_len_alloca, ret_list_len_name) + .into_int_value(), + ctx.i64_type().const_int(1, false), + "next_ret_list_len", + ); + + builder.build_store(ret_list_len_alloca, next_ret_list_len); + } + + builder.build_unconditional_branch(after_bb); + builder.position_at_end(after_bb); + }; + + let index_alloca = incrementing_elem_loop( + builder, + parent, + ctx, + LoopListArg { ptr: list_ptr, len }, + "#index", + None, + ret_list_len_loop, + ); + builder.build_store(index_alloca, ctx.i64_type().const_int(0, false)); + + let final_ret_list_len = builder + .build_load(ret_list_len_alloca, ret_list_len_name) + .into_int_value(); + let ret_list_ptr = allocate_list(env, elem_layout, final_ret_list_len); + + let dest_elem_ptr_alloca = builder.build_alloca(elem_ptr_type, "dest_elem"); + builder.build_store(dest_elem_ptr_alloca, ret_list_ptr); + + let list_loop = |_, elem| { + let call_site_value = + builder.build_call(func_ptr, env.arena.alloc([elem]), "map_func"); + + // set the calling convention explicitly for this call + call_site_value.set_call_convention(crate::llvm::build::FAST_CALL_CONV); + + let should_keep = call_site_value + .try_as_basic_value() + .left() + .unwrap_or_else(|| panic!("LLVM error: Invalid call by pointer.")) + .into_int_value(); + + let loop_bb = ctx.append_basic_block(parent, "loop"); + let after_bb = ctx.append_basic_block(parent, "after_loop"); + + builder.build_conditional_branch(should_keep, loop_bb, after_bb); + builder.position_at_end(loop_bb); + + { + let dest_elem_ptr = builder + .build_load(dest_elem_ptr_alloca, "load_dest_elem_ptr") + .into_pointer_value(); + + builder.build_store(dest_elem_ptr, elem); + + let inc_dest_elem_ptr = BasicValueEnum::PointerValue(unsafe { + builder.build_in_bounds_gep( + dest_elem_ptr, + &[env.ptr_int().const_int(1 as u64, false)], + "increment_dest_elem", + ) + }); + + builder.build_store(dest_elem_ptr_alloca, inc_dest_elem_ptr); + } + + builder.build_unconditional_branch(after_bb); + builder.position_at_end(after_bb); + }; + + incrementing_elem_loop( + builder, + parent, + ctx, + LoopListArg { ptr: list_ptr, len }, + "#index", + Some(index_alloca), + list_loop, + ); + + store_list(env, ret_list_ptr, final_ret_list_len) + }; + + if_list_is_not_empty(env, parent, non_empty_fn, list, list_layout, "List.keepIf") + } + _ => { + unreachable!( + "Invalid function basic value enum or layout for List.keepIf : {:?}", + (func, func_layout) + ); + } + } } /// List.map : List before, (before -> after) -> List after @@ -677,8 +815,6 @@ pub fn list_map<'a, 'ctx, 'env>( let list_ptr = load_list_ptr(builder, list_wrapper, ptr_type); let list_loop = |index, before_elem| { - // The pointer to the element in the input list - let call_site_value = builder.build_call(func_ptr, env.arena.alloc([before_elem]), "map_func"); diff --git a/compiler/gen/tests/gen_list.rs b/compiler/gen/tests/gen_list.rs index 3721e2ecab..b5d8fc49c0 100644 --- a/compiler/gen/tests/gen_list.rs +++ b/compiler/gen/tests/gen_list.rs @@ -136,7 +136,12 @@ mod gen_list { assert_evals_to!( indoc!( r#" - List.keepIf [] (\x -> True) + alwaysTrue : Int -> Bool + alwaysTrue = \_ -> + True + + + List.keepIf [] alwaysTrue "# ), &[], @@ -144,6 +149,40 @@ mod gen_list { ); } + #[test] + fn list_keep_if_always_false_for_non_empty_list() { + assert_evals_to!( + indoc!( + r#" + alwaysFalse : Int -> Bool + alwaysFalse = \i -> + False + + List.keepIf [1,2,3,4,5,6,7,8] alwaysFalse + "# + ), + &[], + &'static [i64] + ); + } + + #[test] + fn list_keep_if_one() { + assert_evals_to!( + indoc!( + r#" + intIsOne : Int -> Bool + intIsOne = \i -> + False + + List.keepIf [1,2,3,4,5,6,7,8] intIsOne + "# + ), + &[1], + &'static [i64] + ); + } + #[test] fn list_map_on_empty_list_with_int_layout() { assert_evals_to!(