Merge pull request #2245 from rtfeldman/wasm-tag-eq

Wasm tag equality
This commit is contained in:
Folkert de Vries 2021-12-21 16:24:14 +01:00 committed by GitHub
commit 18187bc43f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 467 additions and 193 deletions

View file

@ -14,9 +14,9 @@ use roc_mono::ir::{
use roc_mono::layout::{Builtin, Layout, LayoutIds, TagIdIntType, UnionLayout};
use roc_reporting::internal_error;
use crate::layout::{CallConv, ReturnMethod, WasmLayout};
use crate::low_level::{decode_low_level, LowlevelBuildResult};
use crate::storage::{Storage, StoredValue, StoredValueKind};
use crate::layout::{CallConv, ReturnMethod, StackMemoryFormat, WasmLayout};
use crate::low_level::{dispatch_low_level, LowlevelBuildResult};
use crate::storage::{StackMemoryLocation, Storage, StoredValue, StoredValueKind};
use crate::wasm_module::linking::{
DataSymbol, LinkingSection, RelocationSection, WasmObjectSymbol, WASM_SYM_BINDING_WEAK,
WASM_SYM_UNDEFINED,
@ -272,6 +272,7 @@ impl<'a> WasmBackend<'a> {
self.start_block(BlockType::from(ret_type));
for (layout, symbol) in proc.args {
self.symbol_layouts.insert(*symbol, *layout);
let arg_layout = WasmLayout::new(layout);
self.storage
.allocate(&arg_layout, *symbol, StoredValueKind::Parameter);
@ -480,6 +481,8 @@ impl<'a> WasmBackend<'a> {
// make locals for join pointer parameters
let mut jp_param_storages = Vec::with_capacity_in(parameters.len(), self.env.arena);
for parameter in parameters.iter() {
self.symbol_layouts
.insert(parameter.symbol, parameter.layout);
let wasm_layout = WasmLayout::new(&parameter.layout);
let mut param_storage = self.storage.allocate(
&wasm_layout,
@ -645,21 +648,34 @@ impl<'a> WasmBackend<'a> {
field_layouts,
structure,
} => {
if let StoredValue::StackMemory { location, .. } = self.storage.get(structure) {
let (local_id, mut offset) =
location.local_and_offset(self.storage.stack_frame_pointer);
for field in field_layouts.iter().take(*index as usize) {
offset += field.stack_size(PTR_SIZE);
self.storage.ensure_value_has_local(
&mut self.code_builder,
*sym,
storage.to_owned(),
);
let (local_id, mut offset) = match self.storage.get(structure) {
StoredValue::StackMemory { location, .. } => {
location.local_and_offset(self.storage.stack_frame_pointer)
}
self.storage.copy_value_from_memory(
&mut self.code_builder,
*sym,
StoredValue::Local {
value_type,
local_id,
offset,
);
} else {
internal_error!("Unexpected storage for {:?}", structure)
..
} => {
debug_assert!(matches!(value_type, ValueType::I32));
(*local_id, 0)
}
StoredValue::VirtualMachineStack { .. } => {
internal_error!("ensure_value_has_local didn't work")
}
};
for field in field_layouts.iter().take(*index as usize) {
offset += field.stack_size(PTR_SIZE);
}
self.storage
.copy_value_from_memory(&mut self.code_builder, *sym, local_id, offset);
}
Expr::Array { elems, elem_layout } => {
@ -1024,77 +1040,247 @@ impl<'a> WasmBackend<'a> {
return_layout: WasmLayout,
storage: &StoredValue,
) {
let (param_types, ret_type) = self.storage.load_symbols_for_call(
self.env.arena,
&mut self.code_builder,
arguments,
return_sym,
&return_layout,
CallConv::Zig,
);
use LowLevel::*;
let build_result = decode_low_level(
&mut self.code_builder,
&mut self.storage,
lowlevel,
arguments,
&return_layout,
);
use LowlevelBuildResult::*;
match build_result {
Done => {}
BuiltinCall(name) => {
self.call_zig_builtin(name, param_types, ret_type);
match lowlevel {
Eq | NotEq => self.build_eq(lowlevel, arguments, return_sym, return_layout, storage),
PtrCast => {
// Don't want Zig calling convention when casting pointers.
self.storage.load_symbols(&mut self.code_builder, arguments);
}
SpecializedEq | SpecializedNotEq => {
let layout = self.symbol_layouts[&arguments[0]];
let layout_rhs = self.symbol_layouts[&arguments[1]];
debug_assert!(
layout == layout_rhs,
"Cannot do `==` comparison on different types"
Hash => todo!("Generic hash function generation"),
// Almost all lowlevels take this branch, except for the special cases above
_ => {
// Load the arguments using Zig calling convention
let (param_types, ret_type) = self.storage.load_symbols_for_call(
self.env.arena,
&mut self.code_builder,
arguments,
return_sym,
&return_layout,
CallConv::Zig,
);
if layout == Layout::Builtin(Builtin::Str) {
self.call_zig_builtin(bitcode::STR_EQUAL, param_types, ret_type);
} else if layout.stack_size(PTR_SIZE) == 0 {
// If the layout has zero size, and it type-checks, the values must be equal
let value = matches!(build_result, SpecializedEq);
self.code_builder.i32_const(value as i32);
return;
} else {
let ident_ids = self
.interns
.all_ident_ids
.get_mut(&self.env.module_id)
.unwrap();
// Generate instructions OR decide which Zig function to call
let build_result = dispatch_low_level(
&mut self.code_builder,
&mut self.storage,
lowlevel,
arguments,
&return_layout,
);
let (replacement_expr, new_specializations) = self
.helper_proc_gen
.specialize_equals(ident_ids, &layout, arguments);
// If any new specializations were created, register their symbol data
for spec in new_specializations.into_iter() {
self.register_helper_proc(spec);
// Handle the result
use LowlevelBuildResult::*;
match build_result {
Done => {}
BuiltinCall(name) => {
self.call_zig_builtin(name, param_types, ret_type);
}
NotImplemented => {
todo!("Low level operation {:?}", lowlevel)
}
let bool_layout = Layout::Builtin(Builtin::Bool);
self.build_expr(&return_sym, replacement_expr, &bool_layout, storage);
}
}
}
}
if matches!(build_result, SpecializedNotEq) {
fn build_eq(
&mut self,
lowlevel: LowLevel,
arguments: &'a [Symbol],
return_sym: Symbol,
return_layout: WasmLayout,
storage: &StoredValue,
) {
let arg_layout = self.symbol_layouts[&arguments[0]];
let other_arg_layout = self.symbol_layouts[&arguments[1]];
debug_assert!(
arg_layout == other_arg_layout,
"Cannot do `==` comparison on different types"
);
match arg_layout {
Layout::Builtin(
Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal,
) => self.build_eq_number(lowlevel, arguments, return_layout),
Layout::Builtin(Builtin::Str) => {
let (param_types, ret_type) = self.storage.load_symbols_for_call(
self.env.arena,
&mut self.code_builder,
arguments,
return_sym,
&return_layout,
CallConv::Zig,
);
self.call_zig_builtin(bitcode::STR_EQUAL, param_types, ret_type);
if matches!(lowlevel, LowLevel::NotEq) {
self.code_builder.i32_eqz();
}
}
SpecializedHash => {
todo!("Specialized hash functions")
Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_))
| Layout::Struct(_)
| Layout::Union(_)
| Layout::LambdaSet(_) => {
if arg_layout.stack_size(PTR_SIZE) == 0 {
// A zero-size type has only one possible value, like `{}` or `Unit`
// The arguments don't exist at runtime. Just emit True (Eq) or False (NotEq).
let result = matches!(lowlevel, LowLevel::Eq);
self.code_builder.i32_const(result as i32);
} else {
self.build_eq_specialized(&arg_layout, arguments, return_sym, storage);
if matches!(lowlevel, LowLevel::NotEq) {
self.code_builder.i32_eqz();
}
}
}
NotImplemented => {
todo!("Low level operation {:?}", lowlevel)
Layout::RecursivePointer => {
internal_error!("`==` on RecursivePointer should be converted to the parent layout")
}
}
}
fn build_eq_number(
&mut self,
lowlevel: LowLevel,
arguments: &'a [Symbol],
return_layout: WasmLayout,
) {
use StoredValue::*;
match self.storage.get(&arguments[0]).to_owned() {
VirtualMachineStack { value_type, .. } | Local { value_type, .. } => {
self.storage.load_symbols(&mut self.code_builder, arguments);
match lowlevel {
LowLevel::Eq => match value_type {
ValueType::I32 => self.code_builder.i32_eq(),
ValueType::I64 => self.code_builder.i64_eq(),
ValueType::F32 => self.code_builder.f32_eq(),
ValueType::F64 => self.code_builder.f64_eq(),
},
LowLevel::NotEq => match value_type {
ValueType::I32 => self.code_builder.i32_ne(),
ValueType::I64 => self.code_builder.i64_ne(),
ValueType::F32 => self.code_builder.f32_ne(),
ValueType::F64 => self.code_builder.f64_ne(),
},
_ => internal_error!("Low-level op {:?} handled in the wrong place", lowlevel),
}
}
StackMemory {
format,
location: location0,
..
} => {
if let StackMemory {
location: location1,
..
} = self.storage.get(&arguments[1]).to_owned()
{
self.build_eq_num128(format, [location0, location1], arguments, return_layout);
if matches!(lowlevel, LowLevel::NotEq) {
self.code_builder.i32_eqz();
}
}
}
}
}
fn build_eq_num128(
&mut self,
format: StackMemoryFormat,
locations: [StackMemoryLocation; 2],
arguments: &'a [Symbol],
return_layout: WasmLayout,
) {
match format {
StackMemoryFormat::Decimal => {
// Both args are finite
let first = [arguments[0]];
let second = [arguments[1]];
dispatch_low_level(
&mut self.code_builder,
&mut self.storage,
LowLevel::NumIsFinite,
&first,
&return_layout,
);
dispatch_low_level(
&mut self.code_builder,
&mut self.storage,
LowLevel::NumIsFinite,
&second,
&return_layout,
);
self.code_builder.i32_and();
// AND they have the same bytes
self.build_eq_num128_bytes(locations);
self.code_builder.i32_and();
}
StackMemoryFormat::Int128 => self.build_eq_num128_bytes(locations),
StackMemoryFormat::Float128 => todo!("equality for f128"),
StackMemoryFormat::DataStructure => {
internal_error!("Data structure equality is handled elsewhere")
}
}
}
/// Check that two 128-bit numbers contain the same bytes
fn build_eq_num128_bytes(&mut self, locations: [StackMemoryLocation; 2]) {
let (local0, offset0) = locations[0].local_and_offset(self.storage.stack_frame_pointer);
let (local1, offset1) = locations[1].local_and_offset(self.storage.stack_frame_pointer);
self.code_builder.get_local(local0);
self.code_builder.i64_load(Align::Bytes8, offset0);
self.code_builder.get_local(local1);
self.code_builder.i64_load(Align::Bytes8, offset1);
self.code_builder.i64_eq();
self.code_builder.get_local(local0);
self.code_builder.i64_load(Align::Bytes8, offset0 + 8);
self.code_builder.get_local(local1);
self.code_builder.i64_load(Align::Bytes8, offset1 + 8);
self.code_builder.i64_eq();
self.code_builder.i32_and();
}
/// Call a helper procedure that implements `==` for a specific data structure
fn build_eq_specialized(
&mut self,
arg_layout: &Layout<'a>,
arguments: &'a [Symbol],
return_sym: Symbol,
storage: &StoredValue,
) {
let ident_ids = self
.interns
.all_ident_ids
.get_mut(&self.env.module_id)
.unwrap();
// Get an IR expression for the call to the specialized procedure
let (specialized_call_expr, new_specializations) = self
.helper_proc_gen
.call_specialized_equals(ident_ids, arg_layout, arguments);
// If any new specializations were created, register their symbol data
for spec in new_specializations.into_iter() {
self.register_helper_proc(spec);
}
// Generate Wasm code for the IR call expression
let bool_layout = Layout::Builtin(Builtin::Bool);
self.build_expr(&return_sym, specialized_call_expr, &bool_layout, storage);
}
fn load_literal(
&mut self,
lit: &Literal<'a>,

View file

@ -11,13 +11,10 @@ use crate::wasm_module::{Align, CodeBuilder, ValueType::*};
pub enum LowlevelBuildResult {
Done,
BuiltinCall(&'static str),
SpecializedEq,
SpecializedNotEq,
SpecializedHash,
NotImplemented,
}
pub fn decode_low_level<'a>(
pub fn dispatch_low_level<'a>(
code_builder: &mut CodeBuilder<'a>,
storage: &mut Storage<'a>,
lowlevel: LowLevel,
@ -525,109 +522,15 @@ pub fn decode_low_level<'a>(
WasmLayout::StackMemory { .. } => return NotImplemented,
}
}
Eq => {
use StoredValue::*;
match storage.get(&args[0]).to_owned() {
VirtualMachineStack { value_type, .. } | Local { value_type, .. } => {
match value_type {
I32 => code_builder.i32_eq(),
I64 => code_builder.i64_eq(),
F32 => code_builder.f32_eq(),
F64 => code_builder.f64_eq(),
}
}
StackMemory {
format,
location: location0,
..
} => {
if let StackMemory {
location: location1,
..
} = storage.get(&args[1]).to_owned()
{
let stack_frame_pointer = storage.stack_frame_pointer;
let compare_bytes = |code_builder: &mut CodeBuilder| {
let (local0, offset0) = location0.local_and_offset(stack_frame_pointer);
let (local1, offset1) = location1.local_and_offset(stack_frame_pointer);
code_builder.get_local(local0);
code_builder.i64_load(Align::Bytes8, offset0);
code_builder.get_local(local1);
code_builder.i64_load(Align::Bytes8, offset1);
code_builder.i64_eq();
code_builder.get_local(local0);
code_builder.i64_load(Align::Bytes8, offset0 + 8);
code_builder.get_local(local1);
code_builder.i64_load(Align::Bytes8, offset1 + 8);
code_builder.i64_eq();
code_builder.i32_and();
};
match format {
Decimal => {
// Both args are finite
let first = [args[0]];
let second = [args[1]];
decode_low_level(
code_builder,
storage,
LowLevel::NumIsFinite,
&first,
ret_layout,
);
decode_low_level(
code_builder,
storage,
LowLevel::NumIsFinite,
&second,
ret_layout,
);
code_builder.i32_and();
// AND they have the same bytes
compare_bytes(code_builder);
code_builder.i32_and();
}
Int128 => compare_bytes(code_builder),
Float128 => return NotImplemented,
DataStructure => return SpecializedEq,
}
}
}
}
}
NotEq => match storage.get(&args[0]) {
StoredValue::VirtualMachineStack { value_type, .. }
| StoredValue::Local { value_type, .. } => match value_type {
I32 => code_builder.i32_ne(),
I64 => code_builder.i64_ne(),
F32 => code_builder.f32_ne(),
F64 => code_builder.f64_ne(),
},
StoredValue::StackMemory { format, .. } => {
if matches!(format, DataStructure) {
return SpecializedNotEq;
} else {
decode_low_level(code_builder, storage, LowLevel::Eq, args, ret_layout);
code_builder.i32_eqz();
}
}
},
And => code_builder.i32_and(),
Or => code_builder.i32_or(),
Not => code_builder.i32_eqz(),
Hash => return SpecializedHash,
ExpectTrue => return NotImplemented,
PtrCast => {
// We don't need any instructions here, since we've already loaded the value.
// PtrCast just creates separate Symbols and Layouts for the argument and return value.
// This is used for pointer math in refcounting and for pointer equality
}
RefCountInc => return BuiltinCall(bitcode::UTILS_INCREF),
RefCountDec => return BuiltinCall(bitcode::UTILS_DECREF),
Eq | NotEq | Hash | PtrCast => {
internal_error!("{:?} should be handled in backend.rs", lowlevel)
}
}
Done
}

View file

@ -319,9 +319,11 @@ impl<'a> Storage<'a> {
code_builder.i64_load(align, offset);
} else if *size <= 12 && BUILTINS_ZIG_VERSION == ZigVersion::Zig9 {
code_builder.i64_load(align, offset);
code_builder.get_local(local_id);
code_builder.i32_load(align, offset + 8);
} else {
code_builder.i64_load(align, offset);
code_builder.get_local(local_id);
code_builder.i64_load(align, offset + 8);
}
}