diff --git a/compiler/builtins/bitcode/src/list.zig b/compiler/builtins/bitcode/src/list.zig index 2df42ae9e5..380f0f7902 100644 --- a/compiler/builtins/bitcode/src/list.zig +++ b/compiler/builtins/bitcode/src/list.zig @@ -7,6 +7,7 @@ const Allocator = mem.Allocator; const TAG_WIDTH = 8; const EqFn = fn (?[*]u8, ?[*]u8) callconv(.C) bool; +const CompareFn = fn (?[*]u8, ?[*]u8, ?[*]u8) callconv(.C) u8; const Opaque = ?[*]u8; const Inc = fn (?[*]u8) callconv(.C) void; @@ -688,3 +689,7 @@ fn listRangeHelp(allocator: *Allocator, comptime T: type, low: T, high: T) RocLi }, } } + +pub fn listSortWith(list: RocList, transform: Opaque, wrapper: CompareFn, alignment: usize, element_width: usize) callconv(.C) RocList { + return list; +} diff --git a/compiler/builtins/bitcode/src/main.zig b/compiler/builtins/bitcode/src/main.zig index 2661c7c1b0..64b344d7b6 100644 --- a/compiler/builtins/bitcode/src/main.zig +++ b/compiler/builtins/bitcode/src/main.zig @@ -20,6 +20,7 @@ comptime { exportListFn(list.listRepeat, "repeat"); exportListFn(list.listAppend, "append"); exportListFn(list.listRange, "range"); + exportListFn(list.listSortWith, "sort_with"); } // Dict Module diff --git a/compiler/gen/src/llvm/bitcode.rs b/compiler/gen/src/llvm/bitcode.rs index fe7b037500..c918bb01dc 100644 --- a/compiler/gen/src/llvm/bitcode.rs +++ b/compiler/gen/src/llvm/bitcode.rs @@ -2,11 +2,15 @@ use crate::debug_info_init; use crate::llvm::build::{set_name, Env, FAST_CALL_CONV}; use crate::llvm::convert::basic_type_from_layout; use crate::llvm::refcounting::{decrement_refcount_layout, increment_refcount_layout, Mode}; -use inkwell::attributes::{Attribute, AttributeLoc}; +use either::Either; /// Helpers for interacting with the zig that generates bitcode use inkwell::types::{BasicType, BasicTypeEnum}; use inkwell::values::{BasicValueEnum, CallSiteValue, FunctionValue, InstructionValue}; use inkwell::AddressSpace; +use inkwell::{ + attributes::{Attribute, AttributeLoc}, + values::PointerValue, +}; use roc_module::symbol::Symbol; use roc_mono::layout::{Layout, LayoutIds}; @@ -383,3 +387,95 @@ pub fn build_eq_wrapper<'a, 'ctx, 'env>( function_value } + +pub fn build_compare_wrapper<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + layout: &Layout<'a>, +) -> FunctionValue<'ctx> { + let block = env.builder.get_insert_block().expect("to be in a function"); + let di_location = env.builder.get_current_debug_location().unwrap(); + + let symbol = Symbol::COMPARE_REF; + let fn_name = layout_ids + .get(symbol, &layout) + .to_symbol_string(symbol, &env.interns); + + let function_value = match env.module.get_function(fn_name.as_str()) { + Some(function_value) => function_value, + None => { + let arg_type = env.context.i8_type().ptr_type(AddressSpace::Generic); + + let function_value = crate::llvm::refcounting::build_header_help( + env, + &fn_name, + env.context.i8_type().into(), + &[arg_type.into(), arg_type.into(), arg_type.into()], + ); + + let kind_id = Attribute::get_named_enum_kind_id("alwaysinline"); + debug_assert!(kind_id > 0); + let attr = env.context.create_enum_attribute(kind_id, 1); + function_value.add_attribute(AttributeLoc::Function, attr); + + let entry = env.context.append_basic_block(function_value, "entry"); + env.builder.position_at_end(entry); + + debug_info_init!(env, function_value); + + let mut it = function_value.get_param_iter(); + let function_ptr = it.next().unwrap().into_pointer_value(); + let value_ptr1 = it.next().unwrap().into_pointer_value(); + let value_ptr2 = it.next().unwrap().into_pointer_value(); + + set_name( + function_ptr.into(), + Symbol::ARG_1.ident_string(&env.interns), + ); + set_name(value_ptr1.into(), Symbol::ARG_2.ident_string(&env.interns)); + set_name(value_ptr2.into(), Symbol::ARG_3.ident_string(&env.interns)); + + let value_type = basic_type_from_layout(env.arena, env.context, layout, env.ptr_bytes); + let function_type = env + .context + .i8_type() + .fn_type(&[value_type, value_type], false) + .ptr_type(AddressSpace::Generic); + let value_ptr_type = value_type.ptr_type(AddressSpace::Generic); + + let function_cast = + env.builder + .build_bitcast(function_ptr, function_type, "load_opaque"); + let value_cast1 = env + .builder + .build_bitcast(value_ptr1, value_ptr_type, "load_opaque") + .into_pointer_value(); + + let value_cast2 = env + .builder + .build_bitcast(value_ptr2, value_ptr_type, "load_opaque") + .into_pointer_value(); + + let value1 = env.builder.build_load(value_cast1, "load_opaque"); + let value2 = env.builder.build_load(value_cast2, "load_opaque"); + + let call = env.builder.build_call( + function_cast.into_pointer_value(), + &[value1, value2], + "call_user_defined_function", + ); + // call.set_call_convention(user_defined_function.get_call_conventions()); + let result = call.try_as_basic_value().left().unwrap(); + + env.builder.build_return(Some(&result)); + + function_value + } + }; + + env.builder.position_at_end(block); + env.builder + .set_current_debug_location(env.context, di_location); + + function_value +} diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index c6ab07cf0d..8a5f2e0c8a 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -7,8 +7,8 @@ use crate::llvm::build_hash::generic_hash; use crate::llvm::build_list::{ allocate_list, empty_list, empty_polymorphic_list, list_append, list_concat, list_contains, list_get_unsafe, list_join, list_keep_errs, list_keep_if, list_keep_oks, list_len, list_map, - list_map2, list_map3, list_map_with_index, list_prepend, list_repeat, list_reverse, list_set, - list_single, list_sort_with, + list_map2, list_map3, list_map_with_index, list_prepend, list_range, list_repeat, list_reverse, + list_set, list_single, list_sort_with, list_walk_help, }; use crate::llvm::build_str::{ str_concat, str_count_graphemes, str_ends_with, str_from_float, str_from_int, str_from_utf8, diff --git a/compiler/gen/src/llvm/build_list.rs b/compiler/gen/src/llvm/build_list.rs index 4d446d2184..7b60eba5eb 100644 --- a/compiler/gen/src/llvm/build_list.rs +++ b/compiler/gen/src/llvm/build_list.rs @@ -1,7 +1,7 @@ #![allow(clippy::too_many_arguments)] use crate::llvm::bitcode::{ - build_dec_wrapper, build_eq_wrapper, build_inc_wrapper, build_transform_caller, - call_bitcode_fn, call_void_bitcode_fn, + build_compare_wrapper, build_dec_wrapper, build_eq_wrapper, build_inc_wrapper, + build_transform_caller, call_bitcode_fn, call_void_bitcode_fn, }; use crate::llvm::build::{ allocate_with_refcount_help, cast_basic_basic, complex_bitcast, Env, InPlace, @@ -1141,40 +1141,23 @@ pub fn list_sort_with<'a, 'ctx, 'env>( list: BasicValueEnum<'ctx>, element_layout: &Layout<'a>, ) -> BasicValueEnum<'ctx> { - // TODO: decide between returning void pointer or u8 from function passed in. - // TODO: implement soriting in zig let builder = env.builder; - let return_layout = match transform_layout { - Layout::FunctionPointer(_, ret) => ret, - Layout::Closure(_, _, ret) => ret, - _ => unreachable!("not a callable layout"), - }; - - let u8_ptr = env.context.i8_type().ptr_type(AddressSpace::Generic); + let u9_ptr = env.context.i8_type().ptr_type(AddressSpace::Generic); let list_i128 = complex_bitcast(env.builder, list, env.context.i128_type().into(), "to_i128"); let transform_ptr = builder.build_alloca(transform.get_type(), "transform_ptr"); env.builder.build_store(transform_ptr, transform); - let stepper_caller = build_transform_caller( - env, - layout_ids, - transform_layout, - &[*element_layout, *element_layout], - ) - .as_global_value() - .as_pointer_value(); + let compare_wrapper = build_compare_wrapper(env, layout_ids, element_layout) + .as_global_value() + .as_pointer_value(); - let old_element_width = env + let element_width = env .ptr_int() .const_int(element_layout.stack_size(env.ptr_bytes) as u64, false); - let new_element_width = env - .ptr_int() - .const_int(return_layout.stack_size(env.ptr_bytes) as u64, false); - let alignment = element_layout.alignment_bytes(env.ptr_bytes); let alignment_iv = env.ptr_int().const_int(alignment as u64, false); @@ -1184,10 +1167,9 @@ pub fn list_sort_with<'a, 'ctx, 'env>( list_i128, env.builder .build_bitcast(transform_ptr, u8_ptr, "to_opaque"), - stepper_caller.into(), + compare_wrapper.into(), alignment_iv.into(), - old_element_width.into(), - new_element_width.into(), + element_width.into(), ], bitcode::LIST_SORT_WITH, ); diff --git a/compiler/module/src/symbol.rs b/compiler/module/src/symbol.rs index 1c55736d81..227461207a 100644 --- a/compiler/module/src/symbol.rs +++ b/compiler/module/src/symbol.rs @@ -756,6 +756,8 @@ define_builtins! { // A caller (wrapper) that we pass to zig for it to be able to call Roc functions 20 ZIG_FUNCTION_CALLER: "#zig_function_caller" + + 21 COMPARE_REF: "#compare_ref" // TODO: <- a nice comment } 1 NUM: "Num" => { 0 NUM_NUM: "Num" imported // the Num.Num type alias