Equality operator for records

This commit is contained in:
Brian Carroll 2021-12-18 17:25:00 +00:00
parent 55f5956175
commit a1d883600c
4 changed files with 267 additions and 99 deletions

View file

@ -1049,9 +1049,19 @@ impl<'a> WasmBackend<'a> {
}
SpecializedEq | SpecializedNotEq => {
let layout = self.symbol_layouts[&arguments[0]];
let layout_rhs = self.symbol_layouts[&arguments[1]];
debug_assert!(
layout == layout_rhs,
"Cannot do `==` comparison on different types"
);
if layout == Layout::Builtin(Builtin::Str) {
self.call_zig_builtin(bitcode::STR_EQUAL, param_types, ret_type);
} else if layout.stack_size(PTR_SIZE) == 0 {
// If the layout has zero size, and it type-checks, the values must be equal
let value = matches!(build_result, SpecializedEq);
self.code_builder.i32_const(value as i32);
return;
} else {
let ident_ids = self
.interns
@ -1068,7 +1078,8 @@ impl<'a> WasmBackend<'a> {
self.register_helper_proc(spec);
}
self.build_expr(&return_sym, replacement_expr, &layout, storage);
let bool_layout = Layout::Builtin(Builtin::Bool);
self.build_expr(&return_sym, replacement_expr, &bool_layout, storage);
}
if matches!(build_result, SpecializedNotEq) {

View file

@ -7,6 +7,7 @@ use crate::layout::{StackMemoryFormat::*, WasmLayout};
use crate::storage::{Storage, StoredValue};
use crate::wasm_module::{Align, CodeBuilder, ValueType::*};
#[derive(Debug)]
pub enum LowlevelBuildResult {
Done,
BuiltinCall(&'static str),

View file

@ -227,7 +227,7 @@ impl<'a> CodeGenHelp<'a> {
// ============================================================================
//
// TRAVERSE LAYOUT & CREATE PROC NAMES
// CREATE SPECIALIZATIONS
//
// ============================================================================
@ -270,7 +270,9 @@ impl<'a> CodeGenHelp<'a> {
new_procs_info.push((symbol, proc_layout));
let mut visit_child = |child| {
if layout_needs_helper_proc(child, op) {
self.get_or_create_proc_symbols_visit(ident_ids, new_procs_info, op, child);
}
};
let mut visit_children = |children: &'a [Layout]| {
@ -369,82 +371,97 @@ impl<'a> CodeGenHelp<'a> {
) -> Vec<'a, Proc<'a>> {
use HelperOp::*;
// Move the vector out of self, so we can loop over it safely
let mut specs = std::mem::replace(&mut self.specs, Vec::with_capacity_in(0, arena));
// Clone the specializations so we can loop over them safely
// We need to keep self.specs for lookups of sub-procedures during generation
// Maybe could avoid this by separating specs vector from CodeGenHelp, letting backend own both.
let mut specs = self.specs.clone();
let procs_iter = specs.drain(0..).map(|(layout, op, proc_symbol)| match op {
Inc | Dec | DecRef => {
debug_assert!(Self::is_rc_implemented_yet(&layout));
let rc_todo = || todo!("Please update is_rc_implemented_yet for `{:?}`", layout);
let procs_iter = specs.drain(0..).map(|(layout, op, proc_symbol)| {
let (ret_layout, body) = match op {
Inc | Dec | DecRef => (LAYOUT_UNIT, self.refcount_generic(ident_ids, layout, op)),
Eq => (LAYOUT_BOOL, self.eq_generic(ident_ids, layout)),
};
match layout {
Layout::Builtin(
Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal,
) => unreachable!("Not refcounted: {:?}", layout),
Layout::Builtin(Builtin::Str) => {
self.gen_modify_str(ident_ids, op, proc_symbol)
}
Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => {
rc_todo()
}
Layout::Struct(_) => rc_todo(),
Layout::Union(union_layout) => match union_layout {
UnionLayout::NonRecursive(_) => rc_todo(),
UnionLayout::Recursive(_) => rc_todo(),
UnionLayout::NonNullableUnwrapped(_) => rc_todo(),
UnionLayout::NullableWrapped { .. } => rc_todo(),
UnionLayout::NullableUnwrapped { .. } => rc_todo(),
},
Layout::LambdaSet(_) => unreachable!(
"Refcounting on LambdaSet is invalid. Should be a Union at runtime."
),
Layout::RecursivePointer => rc_todo(),
}
}
Eq => {
let eq_todo = || todo!("Specialized `==` operator for `{:?}`", layout);
match layout {
Layout::Builtin(
Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal,
) => unreachable!(
"No generated proc for `==`. Use direct code gen for {:?}",
layout
),
Layout::Builtin(Builtin::Str) => {
unreachable!("No generated helper proc for `==` on Str. Use Zig function.")
}
Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => {
eq_todo()
}
Layout::Struct(_) => eq_todo(),
Layout::Union(union_layout) => {
self.eq_tag_union(ident_ids, proc_symbol, union_layout)
}
Layout::LambdaSet(_) => unreachable!("`==` is not defined on functions"),
Layout::RecursivePointer => eq_todo(),
}
}
});
Vec::from_iter_in(procs_iter, arena)
}
fn return_unit(&self, ident_ids: &mut IdentIds) -> Stmt<'a> {
let unit = self.create_symbol(ident_ids, "unit");
let ret_stmt = self.arena.alloc(Stmt::Ret(unit));
Stmt::Let(unit, Expr::Struct(&[]), LAYOUT_UNIT, ret_stmt)
}
fn gen_args(&self, op: HelperOp, layout: Layout<'a>) -> &'a [(Layout<'a>, Symbol)] {
let roc_value = (layout, Symbol::ARG_1);
match op {
let args: &'a [(Layout<'a>, Symbol)] = match op {
HelperOp::Inc => {
let inc_amount = (self.layout_isize, Symbol::ARG_2);
self.arena.alloc([roc_value, inc_amount])
}
HelperOp::Dec | HelperOp::DecRef => self.arena.alloc([roc_value]),
HelperOp::Eq => self.arena.alloc([roc_value, (layout, Symbol::ARG_2)]),
};
Proc {
name: proc_symbol,
args,
body,
closure_data_layout: None,
ret_layout,
is_self_recursive: SelfRecursive::NotSelfRecursive,
must_own_arguments: false,
host_exposed_layouts: HostExposedLayouts::NotHostExposed,
}
});
Vec::from_iter_in(procs_iter, arena)
}
/// Apply the HelperOp to a field of a data structure
/// Only called while generating bodies of helper procs
/// The list of specializations should be complete by this time
fn apply_op_to_sub_layout(
&mut self,
op: HelperOp,
sub_layout: &Layout<'a>,
arguments: &'a [Symbol],
) -> Expr<'a> {
let found = self
.specs
.iter()
.find(|(l, o, _)| l == sub_layout && *o == op);
if let Some((_, _, proc_name)) = found {
let arg_layouts: &[Layout<'a>] = match op {
HelperOp::Eq => self.arena.alloc([*sub_layout, *sub_layout]),
HelperOp::Inc => self.arena.alloc([*sub_layout, self.layout_isize]),
HelperOp::Dec => self.arena.alloc([*sub_layout]),
HelperOp::DecRef => unreachable!("DecRef is not recursive"),
};
let ret_layout = if matches!(op, HelperOp::Eq) {
&LAYOUT_BOOL
} else {
&LAYOUT_UNIT
};
Expr::Call(Call {
call_type: CallType::ByName {
name: *proc_name,
ret_layout,
arg_layouts,
specialization_id: CallSpecId::BACKEND_DUMMY,
},
arguments,
})
} else {
// By the time we get here (generating helper procs), the list of specializations is complete.
// So if we didn't find one, we must be at a leaf of the layout tree.
debug_assert!(!layout_needs_helper_proc(sub_layout, op));
let lowlevel = match op {
HelperOp::Eq => LowLevel::Eq,
HelperOp::Inc => LowLevel::RefCountInc,
HelperOp::Dec => LowLevel::RefCountDec,
HelperOp::DecRef => unreachable!("DecRef is not recursive"),
};
Expr::Call(Call {
call_type: CallType::LowLevel {
op: lowlevel,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments,
})
}
}
@ -454,13 +471,38 @@ impl<'a> CodeGenHelp<'a> {
//
// ============================================================================
/// Generate a procedure to modify the reference count of a Str
fn gen_modify_str(
fn refcount_generic(
&mut self,
ident_ids: &mut IdentIds,
layout: Layout<'a>,
op: HelperOp,
proc_name: Symbol,
) -> Proc<'a> {
) -> Stmt<'a> {
debug_assert!(Self::is_rc_implemented_yet(&layout));
let rc_todo = || todo!("Please update is_rc_implemented_yet for `{:?}`", layout);
match layout {
Layout::Builtin(
Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal,
) => unreachable!("Not refcounted: {:?}", layout),
Layout::Builtin(Builtin::Str) => self.refcount_str(ident_ids, op),
Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => rc_todo(),
Layout::Struct(_) => rc_todo(),
Layout::Union(_) => rc_todo(),
Layout::LambdaSet(_) => {
unreachable!("Refcounting on LambdaSet is invalid. Should be a Union at runtime.")
}
Layout::RecursivePointer => rc_todo(),
}
}
fn return_unit(&self, ident_ids: &mut IdentIds) -> Stmt<'a> {
let unit = self.create_symbol(ident_ids, "unit");
let ret_stmt = self.arena.alloc(Stmt::Ret(unit));
Stmt::Let(unit, Expr::Struct(&[]), LAYOUT_UNIT, ret_stmt)
}
/// Generate a procedure to modify the reference count of a Str
fn refcount_str(&mut self, ident_ids: &mut IdentIds, op: HelperOp) -> Stmt<'a> {
let string = Symbol::ARG_1;
let layout_isize = self.layout_isize;
@ -562,7 +604,7 @@ impl<'a> CodeGenHelp<'a> {
};
// Combine the statements in sequence
let body = len_stmt(self.arena.alloc(
len_stmt(self.arena.alloc(
//
zero_stmt(self.arena.alloc(
//
@ -571,20 +613,7 @@ impl<'a> CodeGenHelp<'a> {
if_stmt,
)),
)),
));
let args = self.gen_args(op, Layout::Builtin(Builtin::Str));
Proc {
name: proc_name,
args,
body,
closure_data_layout: None,
ret_layout: LAYOUT_UNIT,
is_self_recursive: SelfRecursive::NotSelfRecursive,
must_own_arguments: false,
host_exposed_layouts: HostExposedLayouts::NotHostExposed,
}
))
}
// ============================================================================
@ -593,14 +622,103 @@ impl<'a> CodeGenHelp<'a> {
//
// ============================================================================
fn eq_tag_union(
fn eq_generic(&mut self, ident_ids: &mut IdentIds, layout: Layout<'a>) -> Stmt<'a> {
let eq_todo = || todo!("Specialized `==` operator for `{:?}`", layout);
let arguments = &[Symbol::ARG_1, Symbol::ARG_2];
let main_body = match layout {
Layout::Builtin(
Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal,
) => unreachable!(
"No generated proc for `==`. Use direct code gen for {:?}",
layout
),
Layout::Builtin(Builtin::Str) => {
unreachable!("No generated helper proc for `==` on Str. Use Zig function.")
}
Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => eq_todo(),
Layout::Struct(field_layouts) => self.eq_struct(
ident_ids,
field_layouts,
arguments,
Stmt::Ret(Symbol::BOOL_TRUE),
),
Layout::Union(_) => eq_todo(),
Layout::LambdaSet(_) => unreachable!("`==` is not defined on functions"),
Layout::RecursivePointer => eq_todo(),
};
Stmt::Let(
Symbol::BOOL_TRUE,
Expr::Literal(Literal::Int(1)),
LAYOUT_BOOL,
self.arena.alloc(Stmt::Let(
Symbol::BOOL_FALSE,
Expr::Literal(Literal::Int(0)),
LAYOUT_BOOL,
self.arena.alloc(main_body),
)),
)
}
fn if_false_return_false(&mut self, symbol: Symbol, following: &'a Stmt<'a>) -> Stmt<'a> {
Stmt::Switch {
cond_symbol: symbol,
cond_layout: LAYOUT_BOOL,
branches: self
.arena
.alloc([(0, BranchInfo::None, Stmt::Ret(Symbol::BOOL_FALSE))]),
default_branch: (BranchInfo::None, following),
ret_layout: LAYOUT_BOOL,
}
}
fn eq_struct(
&mut self,
ident_ids: &mut IdentIds,
proc_name: Symbol,
union_layout: UnionLayout<'a>,
) -> Proc<'a> {
field_layouts: &'a [Layout<'a>],
arguments: &'a [Symbol],
following: Stmt<'a>,
) -> Stmt<'a> {
let mut stmt = following;
for (i, layout) in field_layouts.iter().enumerate().rev() {
let field1_name = format!("{:?}_field_{}", arguments[0], i);
let field1_sym = self.create_symbol(ident_ids, &field1_name);
let field1_expr = Expr::StructAtIndex {
index: i as u64,
field_layouts,
structure: arguments[0],
};
let field1_stmt = |next| Stmt::Let(field1_sym, field1_expr, *layout, next);
let field2_name = format!("{:?}_field_{}", arguments[1], i);
let field2_sym = self.create_symbol(ident_ids, &field2_name);
let field2_expr = Expr::StructAtIndex {
index: i as u64,
field_layouts,
structure: arguments[1],
};
let field2_stmt = |next| Stmt::Let(field2_sym, field2_expr, *layout, next);
let sub_layout_args = self.arena.alloc([field1_sym, field2_sym]);
let eq_call_expr = self.apply_op_to_sub_layout(HelperOp::Eq, layout, sub_layout_args);
let eq_call_name = format!("eq_call_{}", i);
let eq_call_sym = self.create_symbol(ident_ids, &eq_call_name);
let eq_call_stmt = |next| Stmt::Let(eq_call_sym, eq_call_expr, LAYOUT_BOOL, next);
stmt = field1_stmt(self.arena.alloc(
//
todo!("return something")
field2_stmt(self.arena.alloc(
//
eq_call_stmt(self.arena.alloc(
//
self.if_false_return_false(eq_call_sym, self.arena.alloc(stmt)),
)),
)),
))
}
stmt
}
}
@ -611,13 +729,25 @@ fn layout_debug_name<'a>(layout: &Layout<'a>) -> &'static str {
Layout::Builtin(Builtin::Set(_)) => "set",
Layout::Builtin(Builtin::Dict(_, _)) => "dict",
Layout::Builtin(Builtin::Str) => "str",
Layout::Builtin(builtin) => {
debug_assert!(!builtin.is_refcounted());
unreachable!("Builtin {:?} is not refcounted", builtin);
}
Layout::Struct(_) => "struct",
Layout::Union(_) => "union",
Layout::LambdaSet(_) => "lambdaset",
Layout::RecursivePointer => "recursive_pointer",
_ => unreachable!("Can't create helper proc name for {:?}", layout),
}
}
fn layout_needs_helper_proc(layout: &Layout, op: HelperOp) -> bool {
match layout {
Layout::Builtin(Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal) => {
false
}
Layout::Builtin(Builtin::Str) => {
matches!(op, HelperOp::Inc | HelperOp::Dec | HelperOp::DecRef)
}
Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_))
| Layout::Struct(_)
| Layout::Union(_)
| Layout::LambdaSet(_)
| Layout::RecursivePointer => true,
}
}

View file

@ -146,12 +146,38 @@ fn neq_bool_tag() {
}
#[test]
#[cfg(any(feature = "gen-llvm"))]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn empty_record() {
assert_evals_to!("{} == {}", true, bool);
assert_evals_to!("{} != {}", false, bool);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn record() {
assert_evals_to!(
"{ x: 123, y: \"Hello\", z: 3.14 } == { x: 123, y: \"Hello\", z: 3.14 }",
true,
bool
);
assert_evals_to!(
"{ x: 234, y: \"Hello\", z: 3.14 } == { x: 123, y: \"Hello\", z: 3.14 }",
false,
bool
);
assert_evals_to!(
"{ x: 123, y: \"World\", z: 3.14 } == { x: 123, y: \"Hello\", z: 3.14 }",
false,
bool
);
assert_evals_to!(
"{ x: 123, y: \"Hello\", z: 1.11 } == { x: 123, y: \"Hello\", z: 3.14 }",
false,
bool
);
}
#[test]
#[cfg(any(feature = "gen-llvm"))]
fn unit() {