Wasm: Move Eq/NotEq into LowLevelCall

This commit is contained in:
Brian Carroll 2022-01-17 09:07:12 +00:00
parent f635dd8776
commit f354b4842b
2 changed files with 292 additions and 329 deletions

View file

@ -1,7 +1,7 @@
use bumpalo::{self, collections::Vec};
use code_builder::Align;
use roc_builtins::bitcode::{self, IntWidth};
use roc_builtins::bitcode::IntWidth;
use roc_collections::all::MutMap;
use roc_module::ident::Ident;
use roc_module::low_level::{LowLevel, LowLevelWrapperType};
@ -15,9 +15,9 @@ use roc_mono::ir::{
use roc_mono::layout::{Builtin, Layout, LayoutIds, TagIdIntType, UnionLayout};
use roc_reporting::internal_error;
use crate::layout::{CallConv, ReturnMethod, StackMemoryFormat, WasmLayout};
use crate::layout::{CallConv, ReturnMethod, WasmLayout};
use crate::low_level::LowLevelCall;
use crate::storage::{StackMemoryLocation, Storage, StoredValue, StoredValueKind};
use crate::storage::{Storage, StoredValue, StoredValueKind};
use crate::wasm_module::linking::{DataSymbol, LinkingSegment, WasmObjectSymbol};
use crate::wasm_module::sections::{DataMode, DataSegment};
use crate::wasm_module::{
@ -809,36 +809,14 @@ impl<'a> WasmBackend<'a> {
ret_layout: &Layout<'a>,
ret_storage: &StoredValue,
) {
use LowLevel::*;
let wasm_layout = WasmLayout::new(ret_layout);
match lowlevel {
Eq | NotEq => self.build_eq_or_neq(
lowlevel,
arguments,
ret_symbol,
wasm_layout,
ret_layout,
ret_storage,
),
PtrCast => {
// Don't want Zig calling convention when casting pointers.
self.storage.load_symbols(&mut self.code_builder, arguments);
}
Hash => todo!("Generic hash function generation"),
// Almost all lowlevels take this branch, except for the special cases above
_ => {
let low_level_call = LowLevelCall {
lowlevel,
arguments,
ret_symbol,
ret_layout: ret_layout.to_owned(),
ret_storage: ret_storage.to_owned(),
};
low_level_call.generate(self);
}
}
let low_level_call = LowLevelCall {
lowlevel,
arguments,
ret_symbol,
ret_layout: ret_layout.to_owned(),
ret_storage: ret_storage.to_owned(),
};
low_level_call.generate(self);
}
/// Generate a call instruction to a Zig builtin function.
@ -847,11 +825,9 @@ impl<'a> WasmBackend<'a> {
pub fn call_zig_builtin_after_loading_args(
&mut self,
name: &'a str,
param_types: Vec<'a, ValueType>,
ret_type: Option<ValueType>,
num_wasm_args: usize,
has_return_val: bool,
) {
let num_wasm_args = param_types.len();
let has_return_val = ret_type.is_some();
let fn_index = self.module.names.functions[name.as_bytes()];
self.called_preload_fns.push(fn_index);
let linker_symbol_index = u32::MAX;
@ -860,6 +836,43 @@ impl<'a> WasmBackend<'a> {
.call(fn_index, linker_symbol_index, num_wasm_args, has_return_val);
}
/// Call a helper procedure that implements `==` for a data structure (not numbers or Str)
/// If this is the first call for this Layout, it will generate the IR for the procedure.
/// Call stack is expr_call_low_level -> LowLevelCall::generate -> call_eq_specialized
/// It's a bit circuitous, but the alternative is to give low_level.rs `pub` access to
/// interns, helper_proc_gen, and expr(). That just seemed all wrong.
pub fn call_eq_specialized(
&mut self,
arguments: &'a [Symbol],
arg_layout: &Layout<'a>,
ret_symbol: Symbol,
ret_storage: &StoredValue,
) {
let ident_ids = self
.interns
.all_ident_ids
.get_mut(&self.env.module_id)
.unwrap();
// Get an IR expression for the call to the specialized procedure
let (specialized_call_expr, new_specializations) = self
.helper_proc_gen
.call_specialized_equals(ident_ids, arg_layout, arguments);
// If any new specializations were created, register their symbol data
for spec in new_specializations.into_iter() {
self.register_helper_proc(spec);
}
// Generate Wasm code for the IR call expression
self.expr(
ret_symbol,
self.env.arena.alloc(specialized_call_expr),
&Layout::Builtin(Builtin::Bool),
ret_storage,
);
}
/*******************************************************************
* Structs
*******************************************************************/
@ -967,9 +980,7 @@ impl<'a> WasmBackend<'a> {
self.code_builder.i32_const(alignment_bytes as i32);
// Call the foreign function. (Zig and C calling conventions are the same for this signature)
let param_types = bumpalo::vec![in self.env.arena; ValueType::I32, ValueType::I32];
let ret_type = Some(ValueType::I32);
self.call_zig_builtin_after_loading_args("roc_alloc", param_types, ret_type);
self.call_zig_builtin_after_loading_args("roc_alloc", 2, true);
// Save the allocation address to a temporary local variable
let local_id = self.storage.create_anonymous_local(ValueType::I32);
@ -1310,232 +1321,4 @@ impl<'a> WasmBackend<'a> {
self.storage
.copy_value_from_memory(&mut self.code_builder, symbol, from_ptr, from_offset);
}
/*******************************************************************
* Equality
*******************************************************************/
fn build_eq_or_neq(
&mut self,
lowlevel: LowLevel,
arguments: &'a [Symbol],
return_sym: Symbol,
return_layout: WasmLayout,
mono_layout: &Layout<'a>,
storage: &StoredValue,
) {
let arg_layout = self.storage.symbol_layouts[&arguments[0]];
let other_arg_layout = self.storage.symbol_layouts[&arguments[1]];
debug_assert!(
arg_layout == other_arg_layout,
"Cannot do `==` comparison on different types"
);
match arg_layout {
Layout::Builtin(
Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal,
) => self.build_eq_or_neq_number(lowlevel, arguments, return_layout, mono_layout),
Layout::Builtin(Builtin::Str) => {
let (param_types, ret_type) = self.storage.load_symbols_for_call(
self.env.arena,
&mut self.code_builder,
arguments,
return_sym,
&return_layout,
CallConv::Zig,
);
self.call_zig_builtin_after_loading_args(bitcode::STR_EQUAL, param_types, ret_type);
if matches!(lowlevel, LowLevel::NotEq) {
self.code_builder.i32_eqz();
}
}
// Empty record is always equal to empty record.
// There are no runtime arguments to check, so just emit true or false.
Layout::Struct(fields) if fields.is_empty() => {
self.code_builder
.i32_const(if lowlevel == LowLevel::Eq { 1 } else { 0 });
}
// Void is always equal to void. This is the type for the contents of the empty list in `[] == []`
// This code will never execute, but we need a true or false value to type-check
Layout::Union(UnionLayout::NonRecursive(tags)) if tags.is_empty() => {
self.code_builder
.i32_const(if lowlevel == LowLevel::Eq { 1 } else { 0 });
}
Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_))
| Layout::Struct(_)
| Layout::Union(_)
| Layout::LambdaSet(_) => {
self.build_eq_specialized(&arg_layout, arguments, return_sym, storage);
if matches!(lowlevel, LowLevel::NotEq) {
self.code_builder.i32_eqz();
}
}
Layout::RecursivePointer => {
internal_error!(
"Tried to apply `==` to RecursivePointer values {:?}",
arguments,
)
}
}
}
fn build_eq_or_neq_number(
&mut self,
lowlevel: LowLevel,
arguments: &'a [Symbol],
return_layout: WasmLayout,
mono_layout: &Layout<'a>,
) {
use StoredValue::*;
match self.storage.get(&arguments[0]).to_owned() {
VirtualMachineStack { value_type, .. } | Local { value_type, .. } => {
self.storage.load_symbols(&mut self.code_builder, arguments);
match lowlevel {
LowLevel::Eq => match value_type {
ValueType::I32 => self.code_builder.i32_eq(),
ValueType::I64 => self.code_builder.i64_eq(),
ValueType::F32 => self.code_builder.f32_eq(),
ValueType::F64 => self.code_builder.f64_eq(),
},
LowLevel::NotEq => match value_type {
ValueType::I32 => self.code_builder.i32_ne(),
ValueType::I64 => self.code_builder.i64_ne(),
ValueType::F32 => self.code_builder.f32_ne(),
ValueType::F64 => self.code_builder.f64_ne(),
},
_ => internal_error!("Low-level op {:?} handled in the wrong place", lowlevel),
}
}
StackMemory {
format,
location: location0,
..
} => {
if let StackMemory {
location: location1,
..
} = self.storage.get(&arguments[1]).to_owned()
{
self.build_eq_num128(
format,
[location0, location1],
arguments,
return_layout,
mono_layout,
);
if matches!(lowlevel, LowLevel::NotEq) {
self.code_builder.i32_eqz();
}
}
}
}
}
fn build_eq_num128(
&mut self,
format: StackMemoryFormat,
locations: [StackMemoryLocation; 2],
arguments: &'a [Symbol],
return_layout: WasmLayout,
mono_layout: &Layout<'a>,
) {
match format {
StackMemoryFormat::Decimal => {
// Both args are finite
let first = [arguments[0]];
let second = [arguments[1]];
// TODO!
//
// dispatch_low_level(
// &mut self.code_builder,
// &mut self.storage,
// LowLevel::NumIsFinite,
// &first,
// &return_layout,
// mono_layout,
// );
// dispatch_low_level(
// &mut self.code_builder,
// &mut self.storage,
// LowLevel::NumIsFinite,
// &second,
// &return_layout,
// mono_layout,
// );
self.code_builder.i32_and();
// AND they have the same bytes
self.build_eq_num128_bytes(locations);
self.code_builder.i32_and();
}
StackMemoryFormat::Int128 => self.build_eq_num128_bytes(locations),
StackMemoryFormat::Float128 => todo!("equality for f128"),
StackMemoryFormat::DataStructure => {
internal_error!("Data structure equality is handled elsewhere")
}
}
}
/// Check that two 128-bit numbers contain the same bytes
fn build_eq_num128_bytes(&mut self, locations: [StackMemoryLocation; 2]) {
let (local0, offset0) = locations[0].local_and_offset(self.storage.stack_frame_pointer);
let (local1, offset1) = locations[1].local_and_offset(self.storage.stack_frame_pointer);
self.code_builder.get_local(local0);
self.code_builder.i64_load(Align::Bytes8, offset0);
self.code_builder.get_local(local1);
self.code_builder.i64_load(Align::Bytes8, offset1);
self.code_builder.i64_eq();
self.code_builder.get_local(local0);
self.code_builder.i64_load(Align::Bytes8, offset0 + 8);
self.code_builder.get_local(local1);
self.code_builder.i64_load(Align::Bytes8, offset1 + 8);
self.code_builder.i64_eq();
self.code_builder.i32_and();
}
/// Call a helper procedure that implements `==` for a specific data structure
fn build_eq_specialized(
&mut self,
arg_layout: &Layout<'a>,
arguments: &'a [Symbol],
return_sym: Symbol,
storage: &StoredValue,
) {
let ident_ids = self
.interns
.all_ident_ids
.get_mut(&self.env.module_id)
.unwrap();
// Get an IR expression for the call to the specialized procedure
let (specialized_call_expr, new_specializations) = self
.helper_proc_gen
.call_specialized_equals(ident_ids, arg_layout, arguments);
// If any new specializations were created, register their symbol data
for spec in new_specializations.into_iter() {
self.register_helper_proc(spec);
}
// Generate Wasm code for the IR call expression
let bool_layout = Layout::Builtin(Builtin::Bool);
self.expr(
return_sym,
self.env.arena.alloc(specialized_call_expr),
&bool_layout,
storage,
);
}
}