Check pointer equality for Eq on Structs

This commit is contained in:
Brian Carroll 2021-12-19 09:06:06 +00:00
parent e847c924dd
commit 570044f88a

View file

@ -360,11 +360,7 @@ impl<'a> CodeGenHelp<'a> {
/// Generate refcounting helper procs, each specialized to a particular Layout. /// Generate refcounting helper procs, each specialized to a particular Layout.
/// For example `List (Result { a: Str, b: Int } Str)` would get its own helper /// For example `List (Result { a: Str, b: Int } Str)` would get its own helper
/// to update the refcounts on the List, the Result and the strings. /// to update the refcounts on the List, the Result and the strings.
pub fn generate_procs( pub fn generate_procs(&self, arena: &'a Bump, ident_ids: &mut IdentIds) -> Vec<'a, Proc<'a>> {
&mut self,
arena: &'a Bump,
ident_ids: &mut IdentIds,
) -> Vec<'a, Proc<'a>> {
use HelperOp::*; use HelperOp::*;
// Clone the specializations so we can loop over them safely // Clone the specializations so we can loop over them safely
@ -407,7 +403,7 @@ impl<'a> CodeGenHelp<'a> {
/// Only called while generating bodies of helper procs /// Only called while generating bodies of helper procs
/// The list of specializations should be complete by this time /// The list of specializations should be complete by this time
fn apply_op_to_sub_layout( fn apply_op_to_sub_layout(
&mut self, &self,
op: HelperOp, op: HelperOp,
sub_layout: &Layout<'a>, sub_layout: &Layout<'a>,
arguments: &'a [Symbol], arguments: &'a [Symbol],
@ -468,7 +464,7 @@ impl<'a> CodeGenHelp<'a> {
// ============================================================================ // ============================================================================
fn refcount_generic( fn refcount_generic(
&mut self, &self,
ident_ids: &mut IdentIds, ident_ids: &mut IdentIds,
layout: Layout<'a>, layout: Layout<'a>,
op: HelperOp, op: HelperOp,
@ -502,7 +498,7 @@ impl<'a> CodeGenHelp<'a> {
// for the 'pointer' and 'integer' versions of the address. // for the 'pointer' and 'integer' versions of the address.
// This helps to avoid issues with the backends Symbol->Layout mapping. // This helps to avoid issues with the backends Symbol->Layout mapping.
fn rc_ptr_from_struct( fn rc_ptr_from_struct(
&mut self, &self,
ident_ids: &mut IdentIds, ident_ids: &mut IdentIds,
structure: Symbol, structure: Symbol,
rc_ptr_sym: Symbol, rc_ptr_sym: Symbol,
@ -562,7 +558,7 @@ impl<'a> CodeGenHelp<'a> {
} }
/// Generate a procedure to modify the reference count of a Str /// Generate a procedure to modify the reference count of a Str
fn refcount_str(&mut self, ident_ids: &mut IdentIds, op: HelperOp) -> Stmt<'a> { fn refcount_str(&self, ident_ids: &mut IdentIds, op: HelperOp) -> Stmt<'a> {
let string = Symbol::ARG_1; let string = Symbol::ARG_1;
let layout_isize = self.layout_isize; let layout_isize = self.layout_isize;
@ -679,11 +675,9 @@ impl<'a> CodeGenHelp<'a> {
// //
// ============================================================================ // ============================================================================
fn eq_generic(&mut self, ident_ids: &mut IdentIds, layout: Layout<'a>) -> Stmt<'a> { fn eq_generic(&self, ident_ids: &mut IdentIds, layout: Layout<'a>) -> Stmt<'a> {
let eq_todo = || todo!("Specialized `==` operator for `{:?}`", layout); let eq_todo = || todo!("Specialized `==` operator for `{:?}`", layout);
let arguments = &[Symbol::ARG_1, Symbol::ARG_2];
let main_body = match layout { let main_body = match layout {
Layout::Builtin( Layout::Builtin(
Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal, Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal,
@ -695,12 +689,7 @@ impl<'a> CodeGenHelp<'a> {
unreachable!("No generated helper proc for `==` on Str. Use Zig function.") unreachable!("No generated helper proc for `==` on Str. Use Zig function.")
} }
Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => eq_todo(), Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => eq_todo(),
Layout::Struct(field_layouts) => self.eq_struct( Layout::Struct(field_layouts) => self.eq_struct(ident_ids, field_layouts),
ident_ids,
field_layouts,
arguments,
Stmt::Ret(Symbol::BOOL_TRUE),
),
Layout::Union(_) => eq_todo(), Layout::Union(_) => eq_todo(),
Layout::LambdaSet(_) => unreachable!("`==` is not defined on functions"), Layout::LambdaSet(_) => unreachable!("`==` is not defined on functions"),
Layout::RecursivePointer => eq_todo(), Layout::RecursivePointer => eq_todo(),
@ -719,7 +708,64 @@ impl<'a> CodeGenHelp<'a> {
) )
} }
fn if_false_return_false(&mut self, symbol: Symbol, following: &'a Stmt<'a>) -> Stmt<'a> { fn if_pointers_equal_return_true(
&self,
ident_ids: &mut IdentIds,
ptr1: Symbol,
ptr2: Symbol,
following: &'a Stmt<'a>,
) -> Stmt<'a> {
let ptr1_addr = self.create_symbol(ident_ids, &format!("{:?}_addr", ptr1));
let ptr2_addr = self.create_symbol(ident_ids, &format!("{:?}_addr", ptr2));
let ptr_eq = self.create_symbol(ident_ids, &format!("eq_{:?}_{:?}", ptr1_addr, ptr2_addr));
Stmt::Let(
ptr1_addr,
Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::PtrCast,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: self.arena.alloc([ptr1]),
}),
self.layout_isize,
self.arena.alloc(Stmt::Let(
ptr2_addr,
Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::PtrCast,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: self.arena.alloc([ptr2]),
}),
self.layout_isize,
self.arena.alloc(Stmt::Let(
ptr_eq,
Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::Eq,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: self.arena.alloc([ptr1_addr, ptr2_addr]),
}),
LAYOUT_BOOL,
self.arena.alloc(Stmt::Switch {
cond_symbol: ptr_eq,
cond_layout: LAYOUT_BOOL,
branches: self.arena.alloc([(
1,
BranchInfo::None,
Stmt::Ret(Symbol::BOOL_TRUE),
)]),
default_branch: (BranchInfo::None, following),
ret_layout: LAYOUT_BOOL,
}),
)),
)),
)
}
fn if_false_return_false(&self, symbol: Symbol, following: &'a Stmt<'a>) -> Stmt<'a> {
Stmt::Switch { Stmt::Switch {
cond_symbol: symbol, cond_symbol: symbol,
cond_layout: LAYOUT_BOOL, cond_layout: LAYOUT_BOOL,
@ -731,8 +777,23 @@ impl<'a> CodeGenHelp<'a> {
} }
} }
fn eq_struct( fn eq_struct(&self, ident_ids: &mut IdentIds, field_layouts: &'a [Layout<'a>]) -> Stmt<'a> {
&mut self, let else_clause = self.eq_fields(
ident_ids,
field_layouts,
&[Symbol::ARG_1, Symbol::ARG_2],
Stmt::Ret(Symbol::BOOL_TRUE),
);
self.if_pointers_equal_return_true(
ident_ids,
Symbol::ARG_1,
Symbol::ARG_2,
self.arena.alloc(else_clause),
)
}
fn eq_fields(
&self,
ident_ids: &mut IdentIds, ident_ids: &mut IdentIds,
field_layouts: &'a [Layout<'a>], field_layouts: &'a [Layout<'a>],
arguments: &'a [Symbol], arguments: &'a [Symbol],