Equality operator for records

This commit is contained in:
Brian Carroll 2021-12-18 17:25:00 +00:00
parent 55f5956175
commit a1d883600c
4 changed files with 267 additions and 99 deletions

View file

@ -1049,9 +1049,19 @@ impl<'a> WasmBackend<'a> {
} }
SpecializedEq | SpecializedNotEq => { SpecializedEq | SpecializedNotEq => {
let layout = self.symbol_layouts[&arguments[0]]; let layout = self.symbol_layouts[&arguments[0]];
let layout_rhs = self.symbol_layouts[&arguments[1]];
debug_assert!(
layout == layout_rhs,
"Cannot do `==` comparison on different types"
);
if layout == Layout::Builtin(Builtin::Str) { if layout == Layout::Builtin(Builtin::Str) {
self.call_zig_builtin(bitcode::STR_EQUAL, param_types, ret_type); self.call_zig_builtin(bitcode::STR_EQUAL, param_types, ret_type);
} else if layout.stack_size(PTR_SIZE) == 0 {
// If the layout has zero size, and it type-checks, the values must be equal
let value = matches!(build_result, SpecializedEq);
self.code_builder.i32_const(value as i32);
return;
} else { } else {
let ident_ids = self let ident_ids = self
.interns .interns
@ -1068,7 +1078,8 @@ impl<'a> WasmBackend<'a> {
self.register_helper_proc(spec); self.register_helper_proc(spec);
} }
self.build_expr(&return_sym, replacement_expr, &layout, storage); let bool_layout = Layout::Builtin(Builtin::Bool);
self.build_expr(&return_sym, replacement_expr, &bool_layout, storage);
} }
if matches!(build_result, SpecializedNotEq) { if matches!(build_result, SpecializedNotEq) {

View file

@ -7,6 +7,7 @@ use crate::layout::{StackMemoryFormat::*, WasmLayout};
use crate::storage::{Storage, StoredValue}; use crate::storage::{Storage, StoredValue};
use crate::wasm_module::{Align, CodeBuilder, ValueType::*}; use crate::wasm_module::{Align, CodeBuilder, ValueType::*};
#[derive(Debug)]
pub enum LowlevelBuildResult { pub enum LowlevelBuildResult {
Done, Done,
BuiltinCall(&'static str), BuiltinCall(&'static str),

View file

@ -227,7 +227,7 @@ impl<'a> CodeGenHelp<'a> {
// ============================================================================ // ============================================================================
// //
// TRAVERSE LAYOUT & CREATE PROC NAMES // CREATE SPECIALIZATIONS
// //
// ============================================================================ // ============================================================================
@ -270,7 +270,9 @@ impl<'a> CodeGenHelp<'a> {
new_procs_info.push((symbol, proc_layout)); new_procs_info.push((symbol, proc_layout));
let mut visit_child = |child| { let mut visit_child = |child| {
if layout_needs_helper_proc(child, op) {
self.get_or_create_proc_symbols_visit(ident_ids, new_procs_info, op, child); self.get_or_create_proc_symbols_visit(ident_ids, new_procs_info, op, child);
}
}; };
let mut visit_children = |children: &'a [Layout]| { let mut visit_children = |children: &'a [Layout]| {
@ -369,82 +371,97 @@ impl<'a> CodeGenHelp<'a> {
) -> Vec<'a, Proc<'a>> { ) -> Vec<'a, Proc<'a>> {
use HelperOp::*; use HelperOp::*;
// Move the vector out of self, so we can loop over it safely // Clone the specializations so we can loop over them safely
let mut specs = std::mem::replace(&mut self.specs, Vec::with_capacity_in(0, arena)); // We need to keep self.specs for lookups of sub-procedures during generation
// Maybe could avoid this by separating specs vector from CodeGenHelp, letting backend own both.
let mut specs = self.specs.clone();
let procs_iter = specs.drain(0..).map(|(layout, op, proc_symbol)| match op { let procs_iter = specs.drain(0..).map(|(layout, op, proc_symbol)| {
Inc | Dec | DecRef => { let (ret_layout, body) = match op {
debug_assert!(Self::is_rc_implemented_yet(&layout)); Inc | Dec | DecRef => (LAYOUT_UNIT, self.refcount_generic(ident_ids, layout, op)),
let rc_todo = || todo!("Please update is_rc_implemented_yet for `{:?}`", layout); Eq => (LAYOUT_BOOL, self.eq_generic(ident_ids, layout)),
};
match layout {
Layout::Builtin(
Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal,
) => unreachable!("Not refcounted: {:?}", layout),
Layout::Builtin(Builtin::Str) => {
self.gen_modify_str(ident_ids, op, proc_symbol)
}
Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => {
rc_todo()
}
Layout::Struct(_) => rc_todo(),
Layout::Union(union_layout) => match union_layout {
UnionLayout::NonRecursive(_) => rc_todo(),
UnionLayout::Recursive(_) => rc_todo(),
UnionLayout::NonNullableUnwrapped(_) => rc_todo(),
UnionLayout::NullableWrapped { .. } => rc_todo(),
UnionLayout::NullableUnwrapped { .. } => rc_todo(),
},
Layout::LambdaSet(_) => unreachable!(
"Refcounting on LambdaSet is invalid. Should be a Union at runtime."
),
Layout::RecursivePointer => rc_todo(),
}
}
Eq => {
let eq_todo = || todo!("Specialized `==` operator for `{:?}`", layout);
match layout {
Layout::Builtin(
Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal,
) => unreachable!(
"No generated proc for `==`. Use direct code gen for {:?}",
layout
),
Layout::Builtin(Builtin::Str) => {
unreachable!("No generated helper proc for `==` on Str. Use Zig function.")
}
Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => {
eq_todo()
}
Layout::Struct(_) => eq_todo(),
Layout::Union(union_layout) => {
self.eq_tag_union(ident_ids, proc_symbol, union_layout)
}
Layout::LambdaSet(_) => unreachable!("`==` is not defined on functions"),
Layout::RecursivePointer => eq_todo(),
}
}
});
Vec::from_iter_in(procs_iter, arena)
}
fn return_unit(&self, ident_ids: &mut IdentIds) -> Stmt<'a> {
let unit = self.create_symbol(ident_ids, "unit");
let ret_stmt = self.arena.alloc(Stmt::Ret(unit));
Stmt::Let(unit, Expr::Struct(&[]), LAYOUT_UNIT, ret_stmt)
}
fn gen_args(&self, op: HelperOp, layout: Layout<'a>) -> &'a [(Layout<'a>, Symbol)] {
let roc_value = (layout, Symbol::ARG_1); let roc_value = (layout, Symbol::ARG_1);
match op { let args: &'a [(Layout<'a>, Symbol)] = match op {
HelperOp::Inc => { HelperOp::Inc => {
let inc_amount = (self.layout_isize, Symbol::ARG_2); let inc_amount = (self.layout_isize, Symbol::ARG_2);
self.arena.alloc([roc_value, inc_amount]) self.arena.alloc([roc_value, inc_amount])
} }
HelperOp::Dec | HelperOp::DecRef => self.arena.alloc([roc_value]), HelperOp::Dec | HelperOp::DecRef => self.arena.alloc([roc_value]),
HelperOp::Eq => self.arena.alloc([roc_value, (layout, Symbol::ARG_2)]), HelperOp::Eq => self.arena.alloc([roc_value, (layout, Symbol::ARG_2)]),
};
Proc {
name: proc_symbol,
args,
body,
closure_data_layout: None,
ret_layout,
is_self_recursive: SelfRecursive::NotSelfRecursive,
must_own_arguments: false,
host_exposed_layouts: HostExposedLayouts::NotHostExposed,
}
});
Vec::from_iter_in(procs_iter, arena)
}
/// Apply the HelperOp to a field of a data structure
/// Only called while generating bodies of helper procs
/// The list of specializations should be complete by this time
fn apply_op_to_sub_layout(
&mut self,
op: HelperOp,
sub_layout: &Layout<'a>,
arguments: &'a [Symbol],
) -> Expr<'a> {
let found = self
.specs
.iter()
.find(|(l, o, _)| l == sub_layout && *o == op);
if let Some((_, _, proc_name)) = found {
let arg_layouts: &[Layout<'a>] = match op {
HelperOp::Eq => self.arena.alloc([*sub_layout, *sub_layout]),
HelperOp::Inc => self.arena.alloc([*sub_layout, self.layout_isize]),
HelperOp::Dec => self.arena.alloc([*sub_layout]),
HelperOp::DecRef => unreachable!("DecRef is not recursive"),
};
let ret_layout = if matches!(op, HelperOp::Eq) {
&LAYOUT_BOOL
} else {
&LAYOUT_UNIT
};
Expr::Call(Call {
call_type: CallType::ByName {
name: *proc_name,
ret_layout,
arg_layouts,
specialization_id: CallSpecId::BACKEND_DUMMY,
},
arguments,
})
} else {
// By the time we get here (generating helper procs), the list of specializations is complete.
// So if we didn't find one, we must be at a leaf of the layout tree.
debug_assert!(!layout_needs_helper_proc(sub_layout, op));
let lowlevel = match op {
HelperOp::Eq => LowLevel::Eq,
HelperOp::Inc => LowLevel::RefCountInc,
HelperOp::Dec => LowLevel::RefCountDec,
HelperOp::DecRef => unreachable!("DecRef is not recursive"),
};
Expr::Call(Call {
call_type: CallType::LowLevel {
op: lowlevel,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments,
})
} }
} }
@ -454,13 +471,38 @@ impl<'a> CodeGenHelp<'a> {
// //
// ============================================================================ // ============================================================================
/// Generate a procedure to modify the reference count of a Str fn refcount_generic(
fn gen_modify_str(
&mut self, &mut self,
ident_ids: &mut IdentIds, ident_ids: &mut IdentIds,
layout: Layout<'a>,
op: HelperOp, op: HelperOp,
proc_name: Symbol, ) -> Stmt<'a> {
) -> Proc<'a> { debug_assert!(Self::is_rc_implemented_yet(&layout));
let rc_todo = || todo!("Please update is_rc_implemented_yet for `{:?}`", layout);
match layout {
Layout::Builtin(
Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal,
) => unreachable!("Not refcounted: {:?}", layout),
Layout::Builtin(Builtin::Str) => self.refcount_str(ident_ids, op),
Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => rc_todo(),
Layout::Struct(_) => rc_todo(),
Layout::Union(_) => rc_todo(),
Layout::LambdaSet(_) => {
unreachable!("Refcounting on LambdaSet is invalid. Should be a Union at runtime.")
}
Layout::RecursivePointer => rc_todo(),
}
}
fn return_unit(&self, ident_ids: &mut IdentIds) -> Stmt<'a> {
let unit = self.create_symbol(ident_ids, "unit");
let ret_stmt = self.arena.alloc(Stmt::Ret(unit));
Stmt::Let(unit, Expr::Struct(&[]), LAYOUT_UNIT, ret_stmt)
}
/// Generate a procedure to modify the reference count of a Str
fn refcount_str(&mut self, ident_ids: &mut IdentIds, op: HelperOp) -> Stmt<'a> {
let string = Symbol::ARG_1; let string = Symbol::ARG_1;
let layout_isize = self.layout_isize; let layout_isize = self.layout_isize;
@ -562,7 +604,7 @@ impl<'a> CodeGenHelp<'a> {
}; };
// Combine the statements in sequence // Combine the statements in sequence
let body = len_stmt(self.arena.alloc( len_stmt(self.arena.alloc(
// //
zero_stmt(self.arena.alloc( zero_stmt(self.arena.alloc(
// //
@ -571,20 +613,7 @@ impl<'a> CodeGenHelp<'a> {
if_stmt, if_stmt,
)), )),
)), )),
)); ))
let args = self.gen_args(op, Layout::Builtin(Builtin::Str));
Proc {
name: proc_name,
args,
body,
closure_data_layout: None,
ret_layout: LAYOUT_UNIT,
is_self_recursive: SelfRecursive::NotSelfRecursive,
must_own_arguments: false,
host_exposed_layouts: HostExposedLayouts::NotHostExposed,
}
} }
// ============================================================================ // ============================================================================
@ -593,14 +622,103 @@ impl<'a> CodeGenHelp<'a> {
// //
// ============================================================================ // ============================================================================
fn eq_tag_union( fn eq_generic(&mut self, ident_ids: &mut IdentIds, layout: Layout<'a>) -> Stmt<'a> {
let eq_todo = || todo!("Specialized `==` operator for `{:?}`", layout);
let arguments = &[Symbol::ARG_1, Symbol::ARG_2];
let main_body = match layout {
Layout::Builtin(
Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal,
) => unreachable!(
"No generated proc for `==`. Use direct code gen for {:?}",
layout
),
Layout::Builtin(Builtin::Str) => {
unreachable!("No generated helper proc for `==` on Str. Use Zig function.")
}
Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => eq_todo(),
Layout::Struct(field_layouts) => self.eq_struct(
ident_ids,
field_layouts,
arguments,
Stmt::Ret(Symbol::BOOL_TRUE),
),
Layout::Union(_) => eq_todo(),
Layout::LambdaSet(_) => unreachable!("`==` is not defined on functions"),
Layout::RecursivePointer => eq_todo(),
};
Stmt::Let(
Symbol::BOOL_TRUE,
Expr::Literal(Literal::Int(1)),
LAYOUT_BOOL,
self.arena.alloc(Stmt::Let(
Symbol::BOOL_FALSE,
Expr::Literal(Literal::Int(0)),
LAYOUT_BOOL,
self.arena.alloc(main_body),
)),
)
}
fn if_false_return_false(&mut self, symbol: Symbol, following: &'a Stmt<'a>) -> Stmt<'a> {
Stmt::Switch {
cond_symbol: symbol,
cond_layout: LAYOUT_BOOL,
branches: self
.arena
.alloc([(0, BranchInfo::None, Stmt::Ret(Symbol::BOOL_FALSE))]),
default_branch: (BranchInfo::None, following),
ret_layout: LAYOUT_BOOL,
}
}
fn eq_struct(
&mut self, &mut self,
ident_ids: &mut IdentIds, ident_ids: &mut IdentIds,
proc_name: Symbol, field_layouts: &'a [Layout<'a>],
union_layout: UnionLayout<'a>, arguments: &'a [Symbol],
) -> Proc<'a> { following: Stmt<'a>,
) -> Stmt<'a> {
let mut stmt = following;
for (i, layout) in field_layouts.iter().enumerate().rev() {
let field1_name = format!("{:?}_field_{}", arguments[0], i);
let field1_sym = self.create_symbol(ident_ids, &field1_name);
let field1_expr = Expr::StructAtIndex {
index: i as u64,
field_layouts,
structure: arguments[0],
};
let field1_stmt = |next| Stmt::Let(field1_sym, field1_expr, *layout, next);
let field2_name = format!("{:?}_field_{}", arguments[1], i);
let field2_sym = self.create_symbol(ident_ids, &field2_name);
let field2_expr = Expr::StructAtIndex {
index: i as u64,
field_layouts,
structure: arguments[1],
};
let field2_stmt = |next| Stmt::Let(field2_sym, field2_expr, *layout, next);
let sub_layout_args = self.arena.alloc([field1_sym, field2_sym]);
let eq_call_expr = self.apply_op_to_sub_layout(HelperOp::Eq, layout, sub_layout_args);
let eq_call_name = format!("eq_call_{}", i);
let eq_call_sym = self.create_symbol(ident_ids, &eq_call_name);
let eq_call_stmt = |next| Stmt::Let(eq_call_sym, eq_call_expr, LAYOUT_BOOL, next);
stmt = field1_stmt(self.arena.alloc(
// //
todo!("return something") field2_stmt(self.arena.alloc(
//
eq_call_stmt(self.arena.alloc(
//
self.if_false_return_false(eq_call_sym, self.arena.alloc(stmt)),
)),
)),
))
}
stmt
} }
} }
@ -611,13 +729,25 @@ fn layout_debug_name<'a>(layout: &Layout<'a>) -> &'static str {
Layout::Builtin(Builtin::Set(_)) => "set", Layout::Builtin(Builtin::Set(_)) => "set",
Layout::Builtin(Builtin::Dict(_, _)) => "dict", Layout::Builtin(Builtin::Dict(_, _)) => "dict",
Layout::Builtin(Builtin::Str) => "str", Layout::Builtin(Builtin::Str) => "str",
Layout::Builtin(builtin) => {
debug_assert!(!builtin.is_refcounted());
unreachable!("Builtin {:?} is not refcounted", builtin);
}
Layout::Struct(_) => "struct", Layout::Struct(_) => "struct",
Layout::Union(_) => "union", Layout::Union(_) => "union",
Layout::LambdaSet(_) => "lambdaset", Layout::LambdaSet(_) => "lambdaset",
Layout::RecursivePointer => "recursive_pointer", _ => unreachable!("Can't create helper proc name for {:?}", layout),
}
}
fn layout_needs_helper_proc(layout: &Layout, op: HelperOp) -> bool {
match layout {
Layout::Builtin(Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal) => {
false
}
Layout::Builtin(Builtin::Str) => {
matches!(op, HelperOp::Inc | HelperOp::Dec | HelperOp::DecRef)
}
Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_))
| Layout::Struct(_)
| Layout::Union(_)
| Layout::LambdaSet(_)
| Layout::RecursivePointer => true,
} }
} }

View file

@ -146,12 +146,38 @@ fn neq_bool_tag() {
} }
#[test] #[test]
#[cfg(any(feature = "gen-llvm"))] #[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn empty_record() { fn empty_record() {
assert_evals_to!("{} == {}", true, bool); assert_evals_to!("{} == {}", true, bool);
assert_evals_to!("{} != {}", false, bool); assert_evals_to!("{} != {}", false, bool);
} }
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn record() {
assert_evals_to!(
"{ x: 123, y: \"Hello\", z: 3.14 } == { x: 123, y: \"Hello\", z: 3.14 }",
true,
bool
);
assert_evals_to!(
"{ x: 234, y: \"Hello\", z: 3.14 } == { x: 123, y: \"Hello\", z: 3.14 }",
false,
bool
);
assert_evals_to!(
"{ x: 123, y: \"World\", z: 3.14 } == { x: 123, y: \"Hello\", z: 3.14 }",
false,
bool
);
assert_evals_to!(
"{ x: 123, y: \"Hello\", z: 1.11 } == { x: 123, y: \"Hello\", z: 3.14 }",
false,
bool
);
}
#[test] #[test]
#[cfg(any(feature = "gen-llvm"))] #[cfg(any(feature = "gen-llvm"))]
fn unit() { fn unit() {