mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-29 06:44:46 +00:00
commit
18187bc43f
5 changed files with 467 additions and 193 deletions
|
@ -14,9 +14,9 @@ use roc_mono::ir::{
|
|||
use roc_mono::layout::{Builtin, Layout, LayoutIds, TagIdIntType, UnionLayout};
|
||||
use roc_reporting::internal_error;
|
||||
|
||||
use crate::layout::{CallConv, ReturnMethod, WasmLayout};
|
||||
use crate::low_level::{decode_low_level, LowlevelBuildResult};
|
||||
use crate::storage::{Storage, StoredValue, StoredValueKind};
|
||||
use crate::layout::{CallConv, ReturnMethod, StackMemoryFormat, WasmLayout};
|
||||
use crate::low_level::{dispatch_low_level, LowlevelBuildResult};
|
||||
use crate::storage::{StackMemoryLocation, Storage, StoredValue, StoredValueKind};
|
||||
use crate::wasm_module::linking::{
|
||||
DataSymbol, LinkingSection, RelocationSection, WasmObjectSymbol, WASM_SYM_BINDING_WEAK,
|
||||
WASM_SYM_UNDEFINED,
|
||||
|
@ -272,6 +272,7 @@ impl<'a> WasmBackend<'a> {
|
|||
self.start_block(BlockType::from(ret_type));
|
||||
|
||||
for (layout, symbol) in proc.args {
|
||||
self.symbol_layouts.insert(*symbol, *layout);
|
||||
let arg_layout = WasmLayout::new(layout);
|
||||
self.storage
|
||||
.allocate(&arg_layout, *symbol, StoredValueKind::Parameter);
|
||||
|
@ -480,6 +481,8 @@ impl<'a> WasmBackend<'a> {
|
|||
// make locals for join pointer parameters
|
||||
let mut jp_param_storages = Vec::with_capacity_in(parameters.len(), self.env.arena);
|
||||
for parameter in parameters.iter() {
|
||||
self.symbol_layouts
|
||||
.insert(parameter.symbol, parameter.layout);
|
||||
let wasm_layout = WasmLayout::new(¶meter.layout);
|
||||
let mut param_storage = self.storage.allocate(
|
||||
&wasm_layout,
|
||||
|
@ -645,21 +648,34 @@ impl<'a> WasmBackend<'a> {
|
|||
field_layouts,
|
||||
structure,
|
||||
} => {
|
||||
if let StoredValue::StackMemory { location, .. } = self.storage.get(structure) {
|
||||
let (local_id, mut offset) =
|
||||
location.local_and_offset(self.storage.stack_frame_pointer);
|
||||
for field in field_layouts.iter().take(*index as usize) {
|
||||
offset += field.stack_size(PTR_SIZE);
|
||||
self.storage.ensure_value_has_local(
|
||||
&mut self.code_builder,
|
||||
*sym,
|
||||
storage.to_owned(),
|
||||
);
|
||||
let (local_id, mut offset) = match self.storage.get(structure) {
|
||||
StoredValue::StackMemory { location, .. } => {
|
||||
location.local_and_offset(self.storage.stack_frame_pointer)
|
||||
}
|
||||
self.storage.copy_value_from_memory(
|
||||
&mut self.code_builder,
|
||||
*sym,
|
||||
|
||||
StoredValue::Local {
|
||||
value_type,
|
||||
local_id,
|
||||
offset,
|
||||
);
|
||||
} else {
|
||||
internal_error!("Unexpected storage for {:?}", structure)
|
||||
..
|
||||
} => {
|
||||
debug_assert!(matches!(value_type, ValueType::I32));
|
||||
(*local_id, 0)
|
||||
}
|
||||
|
||||
StoredValue::VirtualMachineStack { .. } => {
|
||||
internal_error!("ensure_value_has_local didn't work")
|
||||
}
|
||||
};
|
||||
for field in field_layouts.iter().take(*index as usize) {
|
||||
offset += field.stack_size(PTR_SIZE);
|
||||
}
|
||||
self.storage
|
||||
.copy_value_from_memory(&mut self.code_builder, *sym, local_id, offset);
|
||||
}
|
||||
|
||||
Expr::Array { elems, elem_layout } => {
|
||||
|
@ -1024,77 +1040,247 @@ impl<'a> WasmBackend<'a> {
|
|||
return_layout: WasmLayout,
|
||||
storage: &StoredValue,
|
||||
) {
|
||||
let (param_types, ret_type) = self.storage.load_symbols_for_call(
|
||||
self.env.arena,
|
||||
&mut self.code_builder,
|
||||
arguments,
|
||||
return_sym,
|
||||
&return_layout,
|
||||
CallConv::Zig,
|
||||
);
|
||||
use LowLevel::*;
|
||||
|
||||
let build_result = decode_low_level(
|
||||
&mut self.code_builder,
|
||||
&mut self.storage,
|
||||
lowlevel,
|
||||
arguments,
|
||||
&return_layout,
|
||||
);
|
||||
use LowlevelBuildResult::*;
|
||||
|
||||
match build_result {
|
||||
Done => {}
|
||||
BuiltinCall(name) => {
|
||||
self.call_zig_builtin(name, param_types, ret_type);
|
||||
match lowlevel {
|
||||
Eq | NotEq => self.build_eq(lowlevel, arguments, return_sym, return_layout, storage),
|
||||
PtrCast => {
|
||||
// Don't want Zig calling convention when casting pointers.
|
||||
self.storage.load_symbols(&mut self.code_builder, arguments);
|
||||
}
|
||||
SpecializedEq | SpecializedNotEq => {
|
||||
let layout = self.symbol_layouts[&arguments[0]];
|
||||
let layout_rhs = self.symbol_layouts[&arguments[1]];
|
||||
debug_assert!(
|
||||
layout == layout_rhs,
|
||||
"Cannot do `==` comparison on different types"
|
||||
Hash => todo!("Generic hash function generation"),
|
||||
|
||||
// Almost all lowlevels take this branch, except for the special cases above
|
||||
_ => {
|
||||
// Load the arguments using Zig calling convention
|
||||
let (param_types, ret_type) = self.storage.load_symbols_for_call(
|
||||
self.env.arena,
|
||||
&mut self.code_builder,
|
||||
arguments,
|
||||
return_sym,
|
||||
&return_layout,
|
||||
CallConv::Zig,
|
||||
);
|
||||
|
||||
if layout == Layout::Builtin(Builtin::Str) {
|
||||
self.call_zig_builtin(bitcode::STR_EQUAL, param_types, ret_type);
|
||||
} else if layout.stack_size(PTR_SIZE) == 0 {
|
||||
// If the layout has zero size, and it type-checks, the values must be equal
|
||||
let value = matches!(build_result, SpecializedEq);
|
||||
self.code_builder.i32_const(value as i32);
|
||||
return;
|
||||
} else {
|
||||
let ident_ids = self
|
||||
.interns
|
||||
.all_ident_ids
|
||||
.get_mut(&self.env.module_id)
|
||||
.unwrap();
|
||||
// Generate instructions OR decide which Zig function to call
|
||||
let build_result = dispatch_low_level(
|
||||
&mut self.code_builder,
|
||||
&mut self.storage,
|
||||
lowlevel,
|
||||
arguments,
|
||||
&return_layout,
|
||||
);
|
||||
|
||||
let (replacement_expr, new_specializations) = self
|
||||
.helper_proc_gen
|
||||
.specialize_equals(ident_ids, &layout, arguments);
|
||||
|
||||
// If any new specializations were created, register their symbol data
|
||||
for spec in new_specializations.into_iter() {
|
||||
self.register_helper_proc(spec);
|
||||
// Handle the result
|
||||
use LowlevelBuildResult::*;
|
||||
match build_result {
|
||||
Done => {}
|
||||
BuiltinCall(name) => {
|
||||
self.call_zig_builtin(name, param_types, ret_type);
|
||||
}
|
||||
NotImplemented => {
|
||||
todo!("Low level operation {:?}", lowlevel)
|
||||
}
|
||||
|
||||
let bool_layout = Layout::Builtin(Builtin::Bool);
|
||||
self.build_expr(&return_sym, replacement_expr, &bool_layout, storage);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if matches!(build_result, SpecializedNotEq) {
|
||||
fn build_eq(
|
||||
&mut self,
|
||||
lowlevel: LowLevel,
|
||||
arguments: &'a [Symbol],
|
||||
return_sym: Symbol,
|
||||
return_layout: WasmLayout,
|
||||
storage: &StoredValue,
|
||||
) {
|
||||
let arg_layout = self.symbol_layouts[&arguments[0]];
|
||||
let other_arg_layout = self.symbol_layouts[&arguments[1]];
|
||||
debug_assert!(
|
||||
arg_layout == other_arg_layout,
|
||||
"Cannot do `==` comparison on different types"
|
||||
);
|
||||
|
||||
match arg_layout {
|
||||
Layout::Builtin(
|
||||
Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal,
|
||||
) => self.build_eq_number(lowlevel, arguments, return_layout),
|
||||
|
||||
Layout::Builtin(Builtin::Str) => {
|
||||
let (param_types, ret_type) = self.storage.load_symbols_for_call(
|
||||
self.env.arena,
|
||||
&mut self.code_builder,
|
||||
arguments,
|
||||
return_sym,
|
||||
&return_layout,
|
||||
CallConv::Zig,
|
||||
);
|
||||
self.call_zig_builtin(bitcode::STR_EQUAL, param_types, ret_type);
|
||||
if matches!(lowlevel, LowLevel::NotEq) {
|
||||
self.code_builder.i32_eqz();
|
||||
}
|
||||
}
|
||||
SpecializedHash => {
|
||||
todo!("Specialized hash functions")
|
||||
|
||||
Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_))
|
||||
| Layout::Struct(_)
|
||||
| Layout::Union(_)
|
||||
| Layout::LambdaSet(_) => {
|
||||
if arg_layout.stack_size(PTR_SIZE) == 0 {
|
||||
// A zero-size type has only one possible value, like `{}` or `Unit`
|
||||
// The arguments don't exist at runtime. Just emit True (Eq) or False (NotEq).
|
||||
let result = matches!(lowlevel, LowLevel::Eq);
|
||||
self.code_builder.i32_const(result as i32);
|
||||
} else {
|
||||
self.build_eq_specialized(&arg_layout, arguments, return_sym, storage);
|
||||
if matches!(lowlevel, LowLevel::NotEq) {
|
||||
self.code_builder.i32_eqz();
|
||||
}
|
||||
}
|
||||
}
|
||||
NotImplemented => {
|
||||
todo!("Low level operation {:?}", lowlevel)
|
||||
|
||||
Layout::RecursivePointer => {
|
||||
internal_error!("`==` on RecursivePointer should be converted to the parent layout")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn build_eq_number(
|
||||
&mut self,
|
||||
lowlevel: LowLevel,
|
||||
arguments: &'a [Symbol],
|
||||
return_layout: WasmLayout,
|
||||
) {
|
||||
use StoredValue::*;
|
||||
match self.storage.get(&arguments[0]).to_owned() {
|
||||
VirtualMachineStack { value_type, .. } | Local { value_type, .. } => {
|
||||
self.storage.load_symbols(&mut self.code_builder, arguments);
|
||||
match lowlevel {
|
||||
LowLevel::Eq => match value_type {
|
||||
ValueType::I32 => self.code_builder.i32_eq(),
|
||||
ValueType::I64 => self.code_builder.i64_eq(),
|
||||
ValueType::F32 => self.code_builder.f32_eq(),
|
||||
ValueType::F64 => self.code_builder.f64_eq(),
|
||||
},
|
||||
LowLevel::NotEq => match value_type {
|
||||
ValueType::I32 => self.code_builder.i32_ne(),
|
||||
ValueType::I64 => self.code_builder.i64_ne(),
|
||||
ValueType::F32 => self.code_builder.f32_ne(),
|
||||
ValueType::F64 => self.code_builder.f64_ne(),
|
||||
},
|
||||
_ => internal_error!("Low-level op {:?} handled in the wrong place", lowlevel),
|
||||
}
|
||||
}
|
||||
StackMemory {
|
||||
format,
|
||||
location: location0,
|
||||
..
|
||||
} => {
|
||||
if let StackMemory {
|
||||
location: location1,
|
||||
..
|
||||
} = self.storage.get(&arguments[1]).to_owned()
|
||||
{
|
||||
self.build_eq_num128(format, [location0, location1], arguments, return_layout);
|
||||
if matches!(lowlevel, LowLevel::NotEq) {
|
||||
self.code_builder.i32_eqz();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn build_eq_num128(
|
||||
&mut self,
|
||||
format: StackMemoryFormat,
|
||||
locations: [StackMemoryLocation; 2],
|
||||
arguments: &'a [Symbol],
|
||||
return_layout: WasmLayout,
|
||||
) {
|
||||
match format {
|
||||
StackMemoryFormat::Decimal => {
|
||||
// Both args are finite
|
||||
let first = [arguments[0]];
|
||||
let second = [arguments[1]];
|
||||
dispatch_low_level(
|
||||
&mut self.code_builder,
|
||||
&mut self.storage,
|
||||
LowLevel::NumIsFinite,
|
||||
&first,
|
||||
&return_layout,
|
||||
);
|
||||
dispatch_low_level(
|
||||
&mut self.code_builder,
|
||||
&mut self.storage,
|
||||
LowLevel::NumIsFinite,
|
||||
&second,
|
||||
&return_layout,
|
||||
);
|
||||
self.code_builder.i32_and();
|
||||
|
||||
// AND they have the same bytes
|
||||
self.build_eq_num128_bytes(locations);
|
||||
self.code_builder.i32_and();
|
||||
}
|
||||
|
||||
StackMemoryFormat::Int128 => self.build_eq_num128_bytes(locations),
|
||||
|
||||
StackMemoryFormat::Float128 => todo!("equality for f128"),
|
||||
|
||||
StackMemoryFormat::DataStructure => {
|
||||
internal_error!("Data structure equality is handled elsewhere")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Check that two 128-bit numbers contain the same bytes
|
||||
fn build_eq_num128_bytes(&mut self, locations: [StackMemoryLocation; 2]) {
|
||||
let (local0, offset0) = locations[0].local_and_offset(self.storage.stack_frame_pointer);
|
||||
let (local1, offset1) = locations[1].local_and_offset(self.storage.stack_frame_pointer);
|
||||
|
||||
self.code_builder.get_local(local0);
|
||||
self.code_builder.i64_load(Align::Bytes8, offset0);
|
||||
self.code_builder.get_local(local1);
|
||||
self.code_builder.i64_load(Align::Bytes8, offset1);
|
||||
self.code_builder.i64_eq();
|
||||
|
||||
self.code_builder.get_local(local0);
|
||||
self.code_builder.i64_load(Align::Bytes8, offset0 + 8);
|
||||
self.code_builder.get_local(local1);
|
||||
self.code_builder.i64_load(Align::Bytes8, offset1 + 8);
|
||||
self.code_builder.i64_eq();
|
||||
|
||||
self.code_builder.i32_and();
|
||||
}
|
||||
|
||||
/// Call a helper procedure that implements `==` for a specific data structure
|
||||
fn build_eq_specialized(
|
||||
&mut self,
|
||||
arg_layout: &Layout<'a>,
|
||||
arguments: &'a [Symbol],
|
||||
return_sym: Symbol,
|
||||
storage: &StoredValue,
|
||||
) {
|
||||
let ident_ids = self
|
||||
.interns
|
||||
.all_ident_ids
|
||||
.get_mut(&self.env.module_id)
|
||||
.unwrap();
|
||||
|
||||
// Get an IR expression for the call to the specialized procedure
|
||||
let (specialized_call_expr, new_specializations) = self
|
||||
.helper_proc_gen
|
||||
.call_specialized_equals(ident_ids, arg_layout, arguments);
|
||||
|
||||
// If any new specializations were created, register their symbol data
|
||||
for spec in new_specializations.into_iter() {
|
||||
self.register_helper_proc(spec);
|
||||
}
|
||||
|
||||
// Generate Wasm code for the IR call expression
|
||||
let bool_layout = Layout::Builtin(Builtin::Bool);
|
||||
self.build_expr(&return_sym, specialized_call_expr, &bool_layout, storage);
|
||||
}
|
||||
|
||||
fn load_literal(
|
||||
&mut self,
|
||||
lit: &Literal<'a>,
|
||||
|
|
|
@ -11,13 +11,10 @@ use crate::wasm_module::{Align, CodeBuilder, ValueType::*};
|
|||
pub enum LowlevelBuildResult {
|
||||
Done,
|
||||
BuiltinCall(&'static str),
|
||||
SpecializedEq,
|
||||
SpecializedNotEq,
|
||||
SpecializedHash,
|
||||
NotImplemented,
|
||||
}
|
||||
|
||||
pub fn decode_low_level<'a>(
|
||||
pub fn dispatch_low_level<'a>(
|
||||
code_builder: &mut CodeBuilder<'a>,
|
||||
storage: &mut Storage<'a>,
|
||||
lowlevel: LowLevel,
|
||||
|
@ -525,109 +522,15 @@ pub fn decode_low_level<'a>(
|
|||
WasmLayout::StackMemory { .. } => return NotImplemented,
|
||||
}
|
||||
}
|
||||
Eq => {
|
||||
use StoredValue::*;
|
||||
match storage.get(&args[0]).to_owned() {
|
||||
VirtualMachineStack { value_type, .. } | Local { value_type, .. } => {
|
||||
match value_type {
|
||||
I32 => code_builder.i32_eq(),
|
||||
I64 => code_builder.i64_eq(),
|
||||
F32 => code_builder.f32_eq(),
|
||||
F64 => code_builder.f64_eq(),
|
||||
}
|
||||
}
|
||||
StackMemory {
|
||||
format,
|
||||
location: location0,
|
||||
..
|
||||
} => {
|
||||
if let StackMemory {
|
||||
location: location1,
|
||||
..
|
||||
} = storage.get(&args[1]).to_owned()
|
||||
{
|
||||
let stack_frame_pointer = storage.stack_frame_pointer;
|
||||
let compare_bytes = |code_builder: &mut CodeBuilder| {
|
||||
let (local0, offset0) = location0.local_and_offset(stack_frame_pointer);
|
||||
let (local1, offset1) = location1.local_and_offset(stack_frame_pointer);
|
||||
|
||||
code_builder.get_local(local0);
|
||||
code_builder.i64_load(Align::Bytes8, offset0);
|
||||
code_builder.get_local(local1);
|
||||
code_builder.i64_load(Align::Bytes8, offset1);
|
||||
code_builder.i64_eq();
|
||||
|
||||
code_builder.get_local(local0);
|
||||
code_builder.i64_load(Align::Bytes8, offset0 + 8);
|
||||
code_builder.get_local(local1);
|
||||
code_builder.i64_load(Align::Bytes8, offset1 + 8);
|
||||
code_builder.i64_eq();
|
||||
|
||||
code_builder.i32_and();
|
||||
};
|
||||
|
||||
match format {
|
||||
Decimal => {
|
||||
// Both args are finite
|
||||
let first = [args[0]];
|
||||
let second = [args[1]];
|
||||
decode_low_level(
|
||||
code_builder,
|
||||
storage,
|
||||
LowLevel::NumIsFinite,
|
||||
&first,
|
||||
ret_layout,
|
||||
);
|
||||
decode_low_level(
|
||||
code_builder,
|
||||
storage,
|
||||
LowLevel::NumIsFinite,
|
||||
&second,
|
||||
ret_layout,
|
||||
);
|
||||
code_builder.i32_and();
|
||||
|
||||
// AND they have the same bytes
|
||||
compare_bytes(code_builder);
|
||||
code_builder.i32_and();
|
||||
}
|
||||
Int128 => compare_bytes(code_builder),
|
||||
Float128 => return NotImplemented,
|
||||
DataStructure => return SpecializedEq,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
NotEq => match storage.get(&args[0]) {
|
||||
StoredValue::VirtualMachineStack { value_type, .. }
|
||||
| StoredValue::Local { value_type, .. } => match value_type {
|
||||
I32 => code_builder.i32_ne(),
|
||||
I64 => code_builder.i64_ne(),
|
||||
F32 => code_builder.f32_ne(),
|
||||
F64 => code_builder.f64_ne(),
|
||||
},
|
||||
StoredValue::StackMemory { format, .. } => {
|
||||
if matches!(format, DataStructure) {
|
||||
return SpecializedNotEq;
|
||||
} else {
|
||||
decode_low_level(code_builder, storage, LowLevel::Eq, args, ret_layout);
|
||||
code_builder.i32_eqz();
|
||||
}
|
||||
}
|
||||
},
|
||||
And => code_builder.i32_and(),
|
||||
Or => code_builder.i32_or(),
|
||||
Not => code_builder.i32_eqz(),
|
||||
Hash => return SpecializedHash,
|
||||
ExpectTrue => return NotImplemented,
|
||||
PtrCast => {
|
||||
// We don't need any instructions here, since we've already loaded the value.
|
||||
// PtrCast just creates separate Symbols and Layouts for the argument and return value.
|
||||
// This is used for pointer math in refcounting and for pointer equality
|
||||
}
|
||||
RefCountInc => return BuiltinCall(bitcode::UTILS_INCREF),
|
||||
RefCountDec => return BuiltinCall(bitcode::UTILS_DECREF),
|
||||
Eq | NotEq | Hash | PtrCast => {
|
||||
internal_error!("{:?} should be handled in backend.rs", lowlevel)
|
||||
}
|
||||
}
|
||||
Done
|
||||
}
|
||||
|
|
|
@ -319,9 +319,11 @@ impl<'a> Storage<'a> {
|
|||
code_builder.i64_load(align, offset);
|
||||
} else if *size <= 12 && BUILTINS_ZIG_VERSION == ZigVersion::Zig9 {
|
||||
code_builder.i64_load(align, offset);
|
||||
code_builder.get_local(local_id);
|
||||
code_builder.i32_load(align, offset + 8);
|
||||
} else {
|
||||
code_builder.i64_load(align, offset);
|
||||
code_builder.get_local(local_id);
|
||||
code_builder.i64_load(align, offset + 8);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -196,7 +196,7 @@ impl<'a> CodeGenHelp<'a> {
|
|||
|
||||
/// Replace a generic `Lowlevel::Eq` call with a specialized helper proc.
|
||||
/// The helper procs themselves are to be generated later with `generate_procs`
|
||||
pub fn specialize_equals(
|
||||
pub fn call_specialized_equals(
|
||||
&mut self,
|
||||
ident_ids: &mut IdentIds,
|
||||
layout: &Layout<'a>,
|
||||
|
@ -690,7 +690,7 @@ impl<'a> CodeGenHelp<'a> {
|
|||
}
|
||||
Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => eq_todo(),
|
||||
Layout::Struct(field_layouts) => self.eq_struct(ident_ids, field_layouts),
|
||||
Layout::Union(_) => eq_todo(),
|
||||
Layout::Union(union_layout) => self.eq_tag_union(ident_ids, union_layout),
|
||||
Layout::LambdaSet(_) => unreachable!("`==` is not defined on functions"),
|
||||
Layout::RecursivePointer => eq_todo(),
|
||||
};
|
||||
|
@ -715,9 +715,9 @@ impl<'a> CodeGenHelp<'a> {
|
|||
ptr2: Symbol,
|
||||
following: &'a Stmt<'a>,
|
||||
) -> Stmt<'a> {
|
||||
let ptr1_addr = self.create_symbol(ident_ids, &format!("{:?}_addr", ptr1));
|
||||
let ptr2_addr = self.create_symbol(ident_ids, &format!("{:?}_addr", ptr2));
|
||||
let ptr_eq = self.create_symbol(ident_ids, &format!("eq_{:?}_{:?}", ptr1_addr, ptr2_addr));
|
||||
let ptr1_addr = self.create_symbol(ident_ids, "addr1");
|
||||
let ptr2_addr = self.create_symbol(ident_ids, "addr2");
|
||||
let ptr_eq = self.create_symbol(ident_ids, "eq_addr");
|
||||
|
||||
Stmt::Let(
|
||||
ptr1_addr,
|
||||
|
@ -780,7 +780,9 @@ impl<'a> CodeGenHelp<'a> {
|
|||
fn eq_struct(&self, ident_ids: &mut IdentIds, field_layouts: &'a [Layout<'a>]) -> Stmt<'a> {
|
||||
let else_clause = self.eq_fields(
|
||||
ident_ids,
|
||||
0,
|
||||
field_layouts,
|
||||
None,
|
||||
&[Symbol::ARG_1, Symbol::ARG_2],
|
||||
Stmt::Ret(Symbol::BOOL_TRUE),
|
||||
);
|
||||
|
@ -795,14 +797,15 @@ impl<'a> CodeGenHelp<'a> {
|
|||
fn eq_fields(
|
||||
&self,
|
||||
ident_ids: &mut IdentIds,
|
||||
tag_id: u64,
|
||||
field_layouts: &'a [Layout<'a>],
|
||||
rec_ptr_layout: Option<Layout<'a>>,
|
||||
arguments: &'a [Symbol],
|
||||
following: Stmt<'a>,
|
||||
) -> Stmt<'a> {
|
||||
let mut stmt = following;
|
||||
for (i, layout) in field_layouts.iter().enumerate().rev() {
|
||||
let field1_name = format!("{:?}_field_{}", arguments[0], i);
|
||||
let field1_sym = self.create_symbol(ident_ids, &field1_name);
|
||||
let field1_sym = self.create_symbol(ident_ids, &format!("field_1_{}_{}", tag_id, i));
|
||||
let field1_expr = Expr::StructAtIndex {
|
||||
index: i as u64,
|
||||
field_layouts,
|
||||
|
@ -810,8 +813,7 @@ impl<'a> CodeGenHelp<'a> {
|
|||
};
|
||||
let field1_stmt = |next| Stmt::Let(field1_sym, field1_expr, *layout, next);
|
||||
|
||||
let field2_name = format!("{:?}_field_{}", arguments[1], i);
|
||||
let field2_sym = self.create_symbol(ident_ids, &field2_name);
|
||||
let field2_sym = self.create_symbol(ident_ids, &format!("field_2_{}_{}", tag_id, i));
|
||||
let field2_expr = Expr::StructAtIndex {
|
||||
index: i as u64,
|
||||
field_layouts,
|
||||
|
@ -820,7 +822,13 @@ impl<'a> CodeGenHelp<'a> {
|
|||
let field2_stmt = |next| Stmt::Let(field2_sym, field2_expr, *layout, next);
|
||||
|
||||
let sub_layout_args = self.arena.alloc([field1_sym, field2_sym]);
|
||||
let eq_call_expr = self.apply_op_to_sub_layout(HelperOp::Eq, layout, sub_layout_args);
|
||||
let sub_layout = match (layout, rec_ptr_layout) {
|
||||
(Layout::RecursivePointer, Some(rec_layout)) => self.arena.alloc(rec_layout),
|
||||
_ => layout,
|
||||
};
|
||||
|
||||
let eq_call_expr =
|
||||
self.apply_op_to_sub_layout(HelperOp::Eq, sub_layout, sub_layout_args);
|
||||
let eq_call_name = format!("eq_call_{}", i);
|
||||
let eq_call_sym = self.create_symbol(ident_ids, &eq_call_name);
|
||||
let eq_call_stmt = |next| Stmt::Let(eq_call_sym, eq_call_expr, LAYOUT_BOOL, next);
|
||||
|
@ -838,6 +846,181 @@ impl<'a> CodeGenHelp<'a> {
|
|||
}
|
||||
stmt
|
||||
}
|
||||
|
||||
fn eq_tag_union(&self, ident_ids: &mut IdentIds, union_layout: UnionLayout<'a>) -> Stmt<'a> {
|
||||
use UnionLayout::*;
|
||||
|
||||
let main_stmt = match union_layout {
|
||||
NonRecursive(tags) => self.eq_tag_union_help(ident_ids, union_layout, tags, None),
|
||||
|
||||
Recursive(tags) => self.eq_tag_union_help(ident_ids, union_layout, tags, None),
|
||||
|
||||
NonNullableUnwrapped(field_layouts) => self.eq_fields(
|
||||
ident_ids,
|
||||
0,
|
||||
field_layouts,
|
||||
Some(Layout::Union(union_layout)),
|
||||
&[Symbol::ARG_1, Symbol::ARG_2],
|
||||
Stmt::Ret(Symbol::BOOL_TRUE),
|
||||
),
|
||||
|
||||
NullableWrapped {
|
||||
other_tags,
|
||||
nullable_id,
|
||||
} => self.eq_tag_union_help(ident_ids, union_layout, other_tags, Some(nullable_id)),
|
||||
|
||||
NullableUnwrapped {
|
||||
other_fields,
|
||||
nullable_id: n,
|
||||
} => self.eq_tag_union_help(
|
||||
ident_ids,
|
||||
union_layout,
|
||||
self.arena.alloc([other_fields]),
|
||||
Some(n as u16),
|
||||
),
|
||||
};
|
||||
|
||||
self.if_pointers_equal_return_true(
|
||||
ident_ids,
|
||||
Symbol::ARG_1,
|
||||
Symbol::ARG_2,
|
||||
self.arena.alloc(main_stmt),
|
||||
)
|
||||
}
|
||||
|
||||
fn eq_tag_union_help(
|
||||
&self,
|
||||
ident_ids: &mut IdentIds,
|
||||
union_layout: UnionLayout<'a>,
|
||||
tag_layouts: &'a [&'a [Layout<'a>]],
|
||||
nullable_id: Option<u16>,
|
||||
) -> Stmt<'a> {
|
||||
let tag_id_layout = union_layout.tag_id_layout();
|
||||
|
||||
let tag_id_a = self.create_symbol(ident_ids, "tag_id_a");
|
||||
let tag_id_a_stmt = |next| {
|
||||
Stmt::Let(
|
||||
tag_id_a,
|
||||
Expr::GetTagId {
|
||||
structure: Symbol::ARG_1,
|
||||
union_layout,
|
||||
},
|
||||
tag_id_layout,
|
||||
next,
|
||||
)
|
||||
};
|
||||
|
||||
let tag_id_b = self.create_symbol(ident_ids, "tag_id_b");
|
||||
let tag_id_b_stmt = |next| {
|
||||
Stmt::Let(
|
||||
tag_id_b,
|
||||
Expr::GetTagId {
|
||||
structure: Symbol::ARG_2,
|
||||
union_layout,
|
||||
},
|
||||
tag_id_layout,
|
||||
next,
|
||||
)
|
||||
};
|
||||
|
||||
let tag_ids_eq = self.create_symbol(ident_ids, "tag_ids_eq");
|
||||
let tag_ids_eq_stmt = |next| {
|
||||
Stmt::Let(
|
||||
tag_ids_eq,
|
||||
Expr::Call(Call {
|
||||
call_type: CallType::LowLevel {
|
||||
op: LowLevel::Eq,
|
||||
update_mode: UpdateModeId::BACKEND_DUMMY,
|
||||
},
|
||||
arguments: self.arena.alloc([tag_id_a, tag_id_b]),
|
||||
}),
|
||||
LAYOUT_BOOL,
|
||||
next,
|
||||
)
|
||||
};
|
||||
|
||||
let if_equal_ids_stmt = |next| Stmt::Switch {
|
||||
cond_symbol: tag_ids_eq,
|
||||
cond_layout: LAYOUT_BOOL,
|
||||
branches: self
|
||||
.arena
|
||||
.alloc([(0, BranchInfo::None, Stmt::Ret(Symbol::BOOL_FALSE))]),
|
||||
default_branch: (BranchInfo::None, next),
|
||||
ret_layout: LAYOUT_BOOL,
|
||||
};
|
||||
|
||||
//
|
||||
// Switch statement by tag ID
|
||||
//
|
||||
|
||||
let mut tag_branches = Vec::with_capacity_in(tag_layouts.len(), self.arena);
|
||||
|
||||
// If there's a null tag, check it first. We might not need to load any data from memory.
|
||||
if let Some(id) = nullable_id {
|
||||
tag_branches.push((id as u64, BranchInfo::None, Stmt::Ret(Symbol::BOOL_TRUE)))
|
||||
}
|
||||
|
||||
let recursive_ptr_layout = Some(Layout::Union(union_layout));
|
||||
|
||||
let mut tag_id: u64 = 0;
|
||||
for field_layouts in tag_layouts.iter().take(tag_layouts.len() - 1) {
|
||||
if let Some(null_id) = nullable_id {
|
||||
if tag_id == null_id as u64 {
|
||||
tag_id += 1;
|
||||
}
|
||||
}
|
||||
|
||||
tag_branches.push((
|
||||
tag_id,
|
||||
BranchInfo::None,
|
||||
self.eq_fields(
|
||||
ident_ids,
|
||||
tag_id,
|
||||
field_layouts,
|
||||
recursive_ptr_layout,
|
||||
&[Symbol::ARG_1, Symbol::ARG_2],
|
||||
Stmt::Ret(Symbol::BOOL_TRUE),
|
||||
),
|
||||
));
|
||||
|
||||
tag_id += 1;
|
||||
}
|
||||
|
||||
let tag_switch_stmt = Stmt::Switch {
|
||||
cond_symbol: tag_id_a,
|
||||
cond_layout: tag_id_layout,
|
||||
branches: tag_branches.into_bump_slice(),
|
||||
default_branch: (
|
||||
BranchInfo::None,
|
||||
self.arena.alloc(self.eq_fields(
|
||||
ident_ids,
|
||||
tag_id,
|
||||
tag_layouts.last().unwrap(),
|
||||
recursive_ptr_layout,
|
||||
&[Symbol::ARG_1, Symbol::ARG_2],
|
||||
Stmt::Ret(Symbol::BOOL_TRUE),
|
||||
)),
|
||||
),
|
||||
ret_layout: LAYOUT_BOOL,
|
||||
};
|
||||
|
||||
//
|
||||
// combine all the statments
|
||||
//
|
||||
tag_id_a_stmt(self.arena.alloc(
|
||||
//
|
||||
tag_id_b_stmt(self.arena.alloc(
|
||||
//
|
||||
tag_ids_eq_stmt(self.arena.alloc(
|
||||
//
|
||||
if_equal_ids_stmt(self.arena.alloc(
|
||||
//
|
||||
tag_switch_stmt,
|
||||
)),
|
||||
)),
|
||||
)),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to derive a debug function name from a layout
|
||||
|
@ -865,7 +1048,7 @@ fn layout_needs_helper_proc(layout: &Layout, op: HelperOp) -> bool {
|
|||
Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_))
|
||||
| Layout::Struct(_)
|
||||
| Layout::Union(_)
|
||||
| Layout::LambdaSet(_)
|
||||
| Layout::RecursivePointer => true,
|
||||
| Layout::LambdaSet(_) => true,
|
||||
Layout::RecursivePointer => false,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -179,7 +179,7 @@ fn record() {
|
|||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm"))]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
|
||||
fn unit() {
|
||||
assert_evals_to!("Unit == Unit", true, bool);
|
||||
assert_evals_to!("Unit != Unit", false, bool);
|
||||
|
@ -231,7 +231,7 @@ fn large_str() {
|
|||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm"))]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
|
||||
fn eq_result_tag_true() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
|
@ -251,7 +251,7 @@ fn eq_result_tag_true() {
|
|||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm"))]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
|
||||
fn eq_result_tag_false() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
|
@ -271,7 +271,7 @@ fn eq_result_tag_false() {
|
|||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm"))]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
|
||||
fn eq_expr() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
|
@ -293,7 +293,7 @@ fn eq_expr() {
|
|||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm"))]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
|
||||
fn eq_linked_list() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
|
@ -351,7 +351,7 @@ fn eq_linked_list() {
|
|||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm"))]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
|
||||
fn eq_linked_list_false() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
|
@ -373,7 +373,7 @@ fn eq_linked_list_false() {
|
|||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm"))]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
|
||||
fn eq_nullable_expr() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
|
@ -502,7 +502,7 @@ fn list_neq_nested() {
|
|||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm"))]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
|
||||
fn compare_union_same_content() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
|
@ -524,7 +524,7 @@ fn compare_union_same_content() {
|
|||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm"))]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
|
||||
fn compare_recursive_union_same_content() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
|
@ -546,7 +546,7 @@ fn compare_recursive_union_same_content() {
|
|||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm"))]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
|
||||
fn compare_nullable_recursive_union_same_content() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue