Refactor Wasm equality operator

This commit is contained in:
Brian Carroll 2021-12-19 21:06:57 +00:00
parent c510226c15
commit e0ffaca3df

View file

@ -16,7 +16,7 @@ use roc_reporting::internal_error;
use crate::layout::{CallConv, ReturnMethod, StackMemoryFormat, WasmLayout}; use crate::layout::{CallConv, ReturnMethod, StackMemoryFormat, WasmLayout};
use crate::low_level::{dispatch_low_level, LowlevelBuildResult}; use crate::low_level::{dispatch_low_level, LowlevelBuildResult};
use crate::storage::{Storage, StoredValue, StoredValueKind}; use crate::storage::{StackMemoryLocation, Storage, StoredValue, StoredValueKind};
use crate::wasm_module::linking::{ use crate::wasm_module::linking::{
DataSymbol, LinkingSection, RelocationSection, WasmObjectSymbol, WASM_SYM_BINDING_WEAK, DataSymbol, LinkingSection, RelocationSection, WasmObjectSymbol, WASM_SYM_BINDING_WEAK,
WASM_SYM_UNDEFINED, WASM_SYM_UNDEFINED,
@ -1081,24 +1081,17 @@ impl<'a> WasmBackend<'a> {
return_layout: WasmLayout, return_layout: WasmLayout,
storage: &StoredValue, storage: &StoredValue,
) { ) {
// Load the arguments using Zig calling convention
let (param_types, ret_type) = self.storage.load_symbols_for_call(
self.env.arena,
&mut self.code_builder,
arguments,
return_sym,
&return_layout,
CallConv::Zig,
);
use StoredValue::*; use StoredValue::*;
match self.storage.get(&arguments[0]).to_owned() { match self.storage.get(&arguments[0]).to_owned() {
VirtualMachineStack { value_type, .. } | Local { value_type, .. } => match value_type { VirtualMachineStack { value_type, .. } | Local { value_type, .. } => {
ValueType::I32 => self.code_builder.i32_eq(), self.storage.load_symbols(&mut self.code_builder, arguments);
ValueType::I64 => self.code_builder.i64_eq(), match value_type {
ValueType::F32 => self.code_builder.f32_eq(), ValueType::I32 => self.code_builder.i32_eq(),
ValueType::F64 => self.code_builder.f64_eq(), ValueType::I64 => self.code_builder.i64_eq(),
}, ValueType::F32 => self.code_builder.f32_eq(),
ValueType::F64 => self.code_builder.f64_eq(),
}
}
StackMemory { StackMemory {
format, format,
location: location0, location: location0,
@ -1109,97 +1102,123 @@ impl<'a> WasmBackend<'a> {
.. ..
} = self.storage.get(&arguments[1]).to_owned() } = self.storage.get(&arguments[1]).to_owned()
{ {
let stack_frame_pointer = self.storage.stack_frame_pointer; self.build_eq_memory(
let compare_bytes = |code_builder: &mut CodeBuilder| { format,
let (local0, offset0) = location0.local_and_offset(stack_frame_pointer); [location0, location1],
let (local1, offset1) = location1.local_and_offset(stack_frame_pointer); arguments,
return_sym,
return_layout,
storage,
)
}
}
};
}
code_builder.get_local(local0); // Equality for values in memory (as opposed to VM stack)
code_builder.i64_load(Align::Bytes8, offset0); fn build_eq_memory(
code_builder.get_local(local1); &mut self,
code_builder.i64_load(Align::Bytes8, offset1); format: StackMemoryFormat,
code_builder.i64_eq(); locations: [StackMemoryLocation; 2],
arguments: &'a [Symbol],
return_sym: Symbol,
return_layout: WasmLayout,
storage: &StoredValue,
) {
match format {
StackMemoryFormat::Decimal => {
// Both args are finite
let first = [arguments[0]];
let second = [arguments[1]];
dispatch_low_level(
&mut self.code_builder,
&mut self.storage,
LowLevel::NumIsFinite,
&first,
&return_layout,
);
dispatch_low_level(
&mut self.code_builder,
&mut self.storage,
LowLevel::NumIsFinite,
&second,
&return_layout,
);
self.code_builder.i32_and();
code_builder.get_local(local0); // AND they have the same bytes
code_builder.i64_load(Align::Bytes8, offset0 + 8); self.build_eq_help_128bit(locations);
code_builder.get_local(local1); self.code_builder.i32_and();
code_builder.i64_load(Align::Bytes8, offset1 + 8); }
code_builder.i64_eq();
code_builder.i32_and(); StackMemoryFormat::Int128 => self.build_eq_help_128bit(locations),
};
match format { StackMemoryFormat::Float128 => todo!("equality for f128"),
StackMemoryFormat::Decimal => {
// Both args are finite
let first = [arguments[0]];
let second = [arguments[1]];
dispatch_low_level(
&mut self.code_builder,
&mut self.storage,
LowLevel::NumIsFinite,
&first,
&return_layout,
);
dispatch_low_level(
&mut self.code_builder,
&mut self.storage,
LowLevel::NumIsFinite,
&second,
&return_layout,
);
self.code_builder.i32_and();
// AND they have the same bytes StackMemoryFormat::DataStructure => {
compare_bytes(&mut self.code_builder); let layout = self.symbol_layouts[&arguments[0]];
self.code_builder.i32_and(); let layout_rhs = self.symbol_layouts[&arguments[1]];
} debug_assert!(
StackMemoryFormat::Int128 => compare_bytes(&mut self.code_builder), layout == layout_rhs,
StackMemoryFormat::Float128 => todo!("equality for f128"), "Cannot do `==` comparison on different types"
StackMemoryFormat::DataStructure => { );
let layout = self.symbol_layouts[&arguments[0]];
let layout_rhs = self.symbol_layouts[&arguments[1]];
debug_assert!(
layout == layout_rhs,
"Cannot do `==` comparison on different types"
);
if layout == Layout::Builtin(Builtin::Str) { if layout == Layout::Builtin(Builtin::Str) {
self.call_zig_builtin(bitcode::STR_EQUAL, param_types, ret_type); let (param_types, ret_type) = self.storage.load_symbols_for_call(
} else if layout.stack_size(PTR_SIZE) == 0 { self.env.arena,
// Always true: `Unit == Unit`, or `{} == {}` &mut self.code_builder,
self.code_builder.i32_const(1); arguments,
} else { return_sym,
let ident_ids = self &return_layout,
.interns CallConv::Zig,
.all_ident_ids );
.get_mut(&self.env.module_id) self.call_zig_builtin(bitcode::STR_EQUAL, param_types, ret_type);
.unwrap(); } else if layout.stack_size(PTR_SIZE) == 0 {
// Always true: `Unit == Unit`, or `{} == {}`
self.code_builder.i32_const(1);
} else {
let ident_ids = self
.interns
.all_ident_ids
.get_mut(&self.env.module_id)
.unwrap();
let (replacement_expr, new_specializations) = self let (replacement_expr, new_specializations) = self
.helper_proc_gen .helper_proc_gen
.specialize_equals(ident_ids, &layout, arguments); .specialize_equals(ident_ids, &layout, arguments);
// If any new specializations were created, register their symbol data // If any new specializations were created, register their symbol data
for spec in new_specializations.into_iter() { for spec in new_specializations.into_iter() {
self.register_helper_proc(spec); self.register_helper_proc(spec);
}
let bool_layout = Layout::Builtin(Builtin::Bool);
self.build_expr(
&return_sym,
replacement_expr,
&bool_layout,
storage,
);
}
}
} }
let bool_layout = Layout::Builtin(Builtin::Bool);
self.build_expr(&return_sym, replacement_expr, &bool_layout, storage);
} }
} }
} }
} }
/// Equality helper for 128-bit numbers
fn build_eq_help_128bit(&mut self, locations: [StackMemoryLocation; 2]) {
let (local0, offset0) = locations[0].local_and_offset(self.storage.stack_frame_pointer);
let (local1, offset1) = locations[1].local_and_offset(self.storage.stack_frame_pointer);
self.code_builder.get_local(local0);
self.code_builder.i64_load(Align::Bytes8, offset0);
self.code_builder.get_local(local1);
self.code_builder.i64_load(Align::Bytes8, offset1);
self.code_builder.i64_eq();
self.code_builder.get_local(local0);
self.code_builder.i64_load(Align::Bytes8, offset0 + 8);
self.code_builder.get_local(local1);
self.code_builder.i64_load(Align::Bytes8, offset1 + 8);
self.code_builder.i64_eq();
self.code_builder.i32_and();
}
fn load_literal( fn load_literal(
&mut self, &mut self,
lit: &Literal<'a>, lit: &Literal<'a>,