Restructure CodeGenHelp to generate IR immediately, in depth-first traversal

This commit is contained in:
Brian Carroll 2021-12-23 15:26:48 +00:00
parent f314abfed9
commit ca501fdcf1
3 changed files with 360 additions and 374 deletions

View file

@ -247,7 +247,7 @@ fn build_object<'a, B: Backend<'a>>(
let (env, interns, helper_proc_gen) = backend.env_interns_helpers_mut(); let (env, interns, helper_proc_gen) = backend.env_interns_helpers_mut();
let ident_ids = interns.all_ident_ids.get_mut(&module_id).unwrap(); let ident_ids = interns.all_ident_ids.get_mut(&module_id).unwrap();
let helper_procs = helper_proc_gen.generate_procs(arena, ident_ids); let helper_procs = helper_proc_gen.take_procs();
env.module_id.register_debug_idents(ident_ids); env.module_id.register_debug_idents(ident_ids);
helper_procs helper_procs

View file

@ -160,14 +160,7 @@ impl<'a> WasmBackend<'a> {
} }
pub fn generate_helpers(&mut self) -> Vec<'a, Proc<'a>> { pub fn generate_helpers(&mut self) -> Vec<'a, Proc<'a>> {
let ident_ids = self self.helper_proc_gen.take_procs()
.interns
.all_ident_ids
.get_mut(&self.env.module_id)
.unwrap();
self.helper_proc_gen
.generate_procs(self.env.arena, ident_ids)
} }
fn register_helper_proc(&mut self, new_proc_info: (Symbol, ProcLayout<'a>)) { fn register_helper_proc(&mut self, new_proc_info: (Symbol, ProcLayout<'a>)) {
@ -1060,7 +1053,9 @@ impl<'a> WasmBackend<'a> {
use LowLevel::*; use LowLevel::*;
match lowlevel { match lowlevel {
Eq | NotEq => self.build_eq(lowlevel, arguments, return_sym, return_layout, storage), Eq | NotEq => {
self.build_eq_or_neq(lowlevel, arguments, return_sym, return_layout, storage)
}
PtrCast => { PtrCast => {
// Don't want Zig calling convention when casting pointers. // Don't want Zig calling convention when casting pointers.
self.storage.load_symbols(&mut self.code_builder, arguments); self.storage.load_symbols(&mut self.code_builder, arguments);
@ -1103,7 +1098,7 @@ impl<'a> WasmBackend<'a> {
} }
} }
fn build_eq( fn build_eq_or_neq(
&mut self, &mut self,
lowlevel: LowLevel, lowlevel: LowLevel,
arguments: &'a [Symbol], arguments: &'a [Symbol],
@ -1121,7 +1116,7 @@ impl<'a> WasmBackend<'a> {
match arg_layout { match arg_layout {
Layout::Builtin( Layout::Builtin(
Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal, Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal,
) => self.build_eq_number(lowlevel, arguments, return_layout), ) => self.build_eq_or_neq_number(lowlevel, arguments, return_layout),
Layout::Builtin(Builtin::Str) => { Layout::Builtin(Builtin::Str) => {
let (param_types, ret_type) = self.storage.load_symbols_for_call( let (param_types, ret_type) = self.storage.load_symbols_for_call(
@ -1161,7 +1156,7 @@ impl<'a> WasmBackend<'a> {
} }
} }
fn build_eq_number( fn build_eq_or_neq_number(
&mut self, &mut self,
lowlevel: LowLevel, lowlevel: LowLevel,
arguments: &'a [Symbol], arguments: &'a [Symbol],
@ -1295,7 +1290,12 @@ impl<'a> WasmBackend<'a> {
// Generate Wasm code for the IR call expression // Generate Wasm code for the IR call expression
let bool_layout = Layout::Builtin(Builtin::Bool); let bool_layout = Layout::Builtin(Builtin::Bool);
self.build_expr(&return_sym, specialized_call_expr, &bool_layout, storage); self.build_expr(
&return_sym,
self.env.arena.alloc(specialized_call_expr),
&bool_layout,
storage,
);
} }
fn load_literal( fn load_literal(

View file

@ -23,8 +23,8 @@ const ARG_2: Symbol = Symbol::ARG_2;
/// Ref counts are encoded as negative numbers where isize::MIN represents 1 /// Ref counts are encoded as negative numbers where isize::MIN represents 1
pub const REFCOUNT_MAX: usize = 0; pub const REFCOUNT_MAX: usize = 0;
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] #[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HelperOp { enum HelperOp {
Inc, Inc,
Dec, Dec,
DecRef, DecRef,
@ -41,6 +41,19 @@ impl From<&ModifyRc> for HelperOp {
} }
} }
struct SpecializedProc<'a> {
op: HelperOp,
layout: Layout<'a>,
proc: Proc<'a>,
}
#[derive(Debug)]
struct Context<'a> {
new_linker_data: Vec<'a, (Symbol, ProcLayout<'a>)>,
rec_ptr_layout: Option<UnionLayout<'a>>,
op: HelperOp,
}
/// Generate specialized helper procs for code gen /// Generate specialized helper procs for code gen
/// ---------------------------------------------- /// ----------------------------------------------
/// ///
@ -64,9 +77,7 @@ pub struct CodeGenHelp<'a> {
home: ModuleId, home: ModuleId,
ptr_size: u32, ptr_size: u32,
layout_isize: Layout<'a>, layout_isize: Layout<'a>,
/// Specializations to generate specialized_procs: Vec<'a, SpecializedProc<'a>>,
/// Order of insertion is preserved, since it is important for Wasm backend
specs: Vec<'a, (Layout<'a>, HelperOp, Symbol)>,
} }
impl<'a> CodeGenHelp<'a> { impl<'a> CodeGenHelp<'a> {
@ -76,10 +87,19 @@ impl<'a> CodeGenHelp<'a> {
home, home,
ptr_size: intwidth_isize.stack_size(), ptr_size: intwidth_isize.stack_size(),
layout_isize: Layout::Builtin(Builtin::Int(intwidth_isize)), layout_isize: Layout::Builtin(Builtin::Int(intwidth_isize)),
specs: Vec::with_capacity_in(16, arena), specialized_procs: Vec::with_capacity_in(16, arena),
} }
} }
pub fn take_procs(&mut self) -> Vec<'a, Proc<'a>> {
let procs_iter = self
.specialized_procs
.drain(0..)
.map(|SpecializedProc { proc, .. }| proc);
Vec::from_iter_in(procs_iter, self.arena)
}
// ============================================================================ // ============================================================================
// //
// CALL GENERATED PROCS // CALL GENERATED PROCS
@ -107,51 +127,44 @@ impl<'a> CodeGenHelp<'a> {
let arena = self.arena; let arena = self.arena;
let mut ctx = Context {
new_linker_data: Vec::new_in(self.arena),
rec_ptr_layout: None,
op: HelperOp::from(modify),
};
match modify { match modify {
ModifyRc::Inc(structure, amount) => { ModifyRc::Inc(structure, amount) => {
let layout_isize = self.layout_isize; let layout_isize = self.layout_isize;
let (proc_name, new_procs_info) =
self.get_or_create_proc_symbols_recursive(ident_ids, &layout, HelperOp::Inc);
// Define a constant for the amount to increment // Define a constant for the amount to increment
let amount_sym = self.create_symbol(ident_ids, "amount"); let amount_sym = self.create_symbol(ident_ids, "amount");
let amount_expr = Expr::Literal(Literal::Int(*amount as i128)); let amount_expr = Expr::Literal(Literal::Int(*amount as i128));
let amount_stmt = |next| Stmt::Let(amount_sym, amount_expr, layout_isize, next); let amount_stmt = |next| Stmt::Let(amount_sym, amount_expr, layout_isize, next);
// Call helper proc, passing the Roc structure and constant amount // Call helper proc, passing the Roc structure and constant amount
let arg_layouts = arena.alloc([layout, layout_isize]);
let call_result_empty = self.create_symbol(ident_ids, "call_result_empty"); let call_result_empty = self.create_symbol(ident_ids, "call_result_empty");
let call_expr = Expr::Call(Call { let call_expr = self.call_specialized_op(
call_type: CallType::ByName { ident_ids,
name: proc_name, &mut ctx,
ret_layout: &LAYOUT_UNIT, layout,
arg_layouts, arena.alloc([*structure, amount_sym]),
specialization_id: CallSpecId::BACKEND_DUMMY, );
},
arguments: arena.alloc([*structure, amount_sym]),
});
let call_stmt = Stmt::Let(call_result_empty, call_expr, LAYOUT_UNIT, following); let call_stmt = Stmt::Let(call_result_empty, call_expr, LAYOUT_UNIT, following);
let rc_stmt = arena.alloc(amount_stmt(arena.alloc(call_stmt))); let rc_stmt = arena.alloc(amount_stmt(arena.alloc(call_stmt)));
(rc_stmt, new_procs_info) (rc_stmt, ctx.new_linker_data)
} }
ModifyRc::Dec(structure) => { ModifyRc::Dec(structure) => {
let (proc_name, new_procs_info) =
self.get_or_create_proc_symbols_recursive(ident_ids, &layout, HelperOp::Dec);
// Call helper proc, passing the Roc structure // Call helper proc, passing the Roc structure
let call_result_empty = self.create_symbol(ident_ids, "call_result_empty"); let call_result_empty = self.create_symbol(ident_ids, "call_result_empty");
let call_expr = Expr::Call(Call { let call_expr = self.call_specialized_op(
call_type: CallType::ByName { ident_ids,
name: proc_name, &mut ctx,
ret_layout: &LAYOUT_UNIT, layout,
arg_layouts: arena.alloc([layout]), arena.alloc([*structure]),
specialization_id: CallSpecId::BACKEND_DUMMY, );
},
arguments: arena.alloc([*structure]),
});
let rc_stmt = arena.alloc(Stmt::Let( let rc_stmt = arena.alloc(Stmt::Let(
call_result_empty, call_result_empty,
@ -160,7 +173,7 @@ impl<'a> CodeGenHelp<'a> {
following, following,
)); ));
(rc_stmt, new_procs_info) (rc_stmt, ctx.new_linker_data)
} }
ModifyRc::DecRef(structure) => { ModifyRc::DecRef(structure) => {
@ -185,7 +198,7 @@ impl<'a> CodeGenHelp<'a> {
arena.alloc(call_stmt), arena.alloc(call_stmt),
)); ));
(rc_stmt, Vec::new_in(self.arena)) (rc_stmt, ctx.new_linker_data)
} }
} }
} }
@ -204,234 +217,49 @@ impl<'a> CodeGenHelp<'a> {
ident_ids: &mut IdentIds, ident_ids: &mut IdentIds,
layout: &Layout<'a>, layout: &Layout<'a>,
arguments: &'a [Symbol], arguments: &'a [Symbol],
) -> (&'a Expr<'a>, Vec<'a, (Symbol, ProcLayout<'a>)>) { ) -> (Expr<'a>, Vec<'a, (Symbol, ProcLayout<'a>)>) {
// Record a specialization and get its name let mut ctx = Context {
let (proc_name, new_procs_info) = new_linker_data: Vec::new_in(self.arena),
self.get_or_create_proc_symbols_recursive(ident_ids, layout, HelperOp::Eq); rec_ptr_layout: None,
op: HelperOp::Eq,
};
// Call the specialized helper let expr = self.call_specialized_op(ident_ids, &mut ctx, *layout, arguments);
let arg_layouts = self.arena.alloc([*layout, *layout]);
let expr = self.arena.alloc(Expr::Call(Call {
call_type: CallType::ByName {
name: proc_name,
ret_layout: &LAYOUT_BOOL,
arg_layouts,
specialization_id: CallSpecId::BACKEND_DUMMY,
},
arguments,
}));
(expr, new_procs_info) dbg!(&ctx);
(expr, ctx.new_linker_data)
} }
// ============================================================================ // ============================================================================
// //
// CREATE SPECIALIZATIONS // CALL SPECIALIZED OP
// //
// ============================================================================ // ============================================================================
/// Find the Symbol of the procedure for this layout and operation fn call_specialized_op(
/// If any new helper procs are needed for this layout or its children,
/// return their details in a vector.
fn get_or_create_proc_symbols_recursive(
&mut self, &mut self,
ident_ids: &mut IdentIds, ident_ids: &mut IdentIds,
layout: &Layout<'a>, ctx: &mut Context<'a>,
op: HelperOp, layout: Layout<'a>,
) -> (Symbol, Vec<'a, (Symbol, ProcLayout<'a>)>) {
let mut new_procs_info = Vec::new_in(self.arena);
let proc_symbol =
self.get_or_create_proc_symbols_visit(ident_ids, &mut new_procs_info, op, layout);
(proc_symbol, new_procs_info)
}
fn get_or_create_proc_symbols_visit(
&mut self,
ident_ids: &mut IdentIds,
new_procs_info: &mut Vec<'a, (Symbol, ProcLayout<'a>)>,
op: HelperOp,
layout: &Layout<'a>,
) -> Symbol {
if let Layout::LambdaSet(lambda_set) = layout {
return self.get_or_create_proc_symbols_visit(
ident_ids,
new_procs_info,
op,
&lambda_set.runtime_representation(),
);
}
let (symbol, new_proc_layout) = self.get_or_create_proc_symbol(ident_ids, layout, op);
if let Some(proc_layout) = new_proc_layout {
new_procs_info.push((symbol, proc_layout));
let mut visit_child = |child| {
if layout_needs_helper_proc(child, op) {
self.get_or_create_proc_symbols_visit(ident_ids, new_procs_info, op, child);
}
};
let mut visit_children = |children: &'a [Layout]| {
for child in children {
visit_child(child);
}
};
let mut visit_tags = |tags: &'a [&'a [Layout]]| {
for tag in tags {
visit_children(tag);
}
};
match layout {
Layout::Builtin(builtin) => match builtin {
Builtin::Dict(key, value) => {
visit_child(key);
visit_child(value);
}
Builtin::Set(element) | Builtin::List(element) => visit_child(element),
_ => {}
},
Layout::Struct(fields) => visit_children(fields),
Layout::Union(union_layout) => match union_layout {
UnionLayout::NonRecursive(tags) => visit_tags(tags),
UnionLayout::Recursive(tags) => visit_tags(tags),
UnionLayout::NonNullableUnwrapped(fields) => visit_children(fields),
UnionLayout::NullableWrapped { other_tags, .. } => visit_tags(other_tags),
UnionLayout::NullableUnwrapped { other_fields, .. } => {
visit_children(other_fields)
}
},
Layout::LambdaSet(_) => unreachable!(),
Layout::RecursivePointer => {}
}
}
symbol
}
fn get_or_create_proc_symbol(
&mut self,
ident_ids: &mut IdentIds,
layout: &Layout<'a>,
op: HelperOp,
) -> (Symbol, Option<ProcLayout<'a>>) {
let found = self.specs.iter().find(|(l, o, _)| l == layout && *o == op);
if let Some((_, _, existing_symbol)) = found {
(*existing_symbol, None)
} else {
let layout_name = layout_debug_name(layout);
let debug_name = format!("#help{:?}_{}_{}", op, layout_name, self.specs.len());
let new_symbol: Symbol = self.create_symbol(ident_ids, &debug_name);
self.specs.push((*layout, op, new_symbol));
let new_proc_layout = match op {
HelperOp::Inc => Some(ProcLayout {
arguments: self.arena.alloc([*layout, self.layout_isize]),
result: LAYOUT_UNIT,
}),
HelperOp::Dec => Some(ProcLayout {
arguments: self.arena.alloc([*layout]),
result: LAYOUT_UNIT,
}),
HelperOp::DecRef => None,
HelperOp::Eq => Some(ProcLayout {
arguments: self.arena.alloc([*layout, *layout]),
result: LAYOUT_BOOL,
}),
};
(new_symbol, new_proc_layout)
}
}
fn create_symbol(&self, ident_ids: &mut IdentIds, debug_name: &str) -> Symbol {
let ident_id = ident_ids.add(Ident::from(debug_name));
Symbol::new(self.home, ident_id)
}
// ============================================================================
//
// GENERATE PROCS
//
// ============================================================================
/// Generate refcounting helper procs, each specialized to a particular Layout.
/// For example `List (Result { a: Str, b: Int } Str)` would get its own helper
/// to update the refcounts on the List, the Result and the strings.
pub fn generate_procs(&self, arena: &'a Bump, ident_ids: &mut IdentIds) -> Vec<'a, Proc<'a>> {
use HelperOp::*;
// Clone the specializations so we can loop over them safely
// We need to keep self.specs for lookups of sub-procedures during generation
// Maybe could avoid this by separating specs vector from CodeGenHelp, letting backend own both.
let mut specs = self.specs.clone();
let procs_iter = specs.drain(0..).map(|(layout, op, proc_symbol)| {
let (ret_layout, body) = match op {
Inc | Dec | DecRef => (LAYOUT_UNIT, self.refcount_generic(ident_ids, layout, op)),
Eq => (LAYOUT_BOOL, self.eq_generic(ident_ids, layout)),
};
let roc_value = (layout, ARG_1);
let args: &'a [(Layout<'a>, Symbol)] = match op {
HelperOp::Inc => {
let inc_amount = (self.layout_isize, ARG_2);
self.arena.alloc([roc_value, inc_amount])
}
HelperOp::Dec | HelperOp::DecRef => self.arena.alloc([roc_value]),
HelperOp::Eq => self.arena.alloc([roc_value, (layout, ARG_2)]),
};
Proc {
name: proc_symbol,
args,
body,
closure_data_layout: None,
ret_layout,
is_self_recursive: SelfRecursive::NotSelfRecursive,
must_own_arguments: false,
host_exposed_layouts: HostExposedLayouts::NotHostExposed,
}
});
Vec::from_iter_in(procs_iter, arena)
}
/// Apply the HelperOp to a field of a data structure
/// Only called while generating bodies of helper procs
/// The list of specializations should be complete by this time
fn apply_op_to_sub_layout(
&self,
op: HelperOp,
sub_layout: &Layout<'a>,
arguments: &[Symbol], arguments: &[Symbol],
) -> Expr<'a> { ) -> Expr<'a> {
let found = self use HelperOp::*;
.specs
.iter()
.find(|(l, o, _)| l == sub_layout && *o == op);
if let Some((_, _, proc_name)) = found { if layout_needs_helper_proc(&layout, ctx.op) {
let arg_layouts: &[Layout<'a>] = match op { let proc_name = self.find_or_create_proc(ident_ids, ctx, layout);
HelperOp::Eq => self.arena.alloc([*sub_layout, *sub_layout]),
HelperOp::Inc => self.arena.alloc([*sub_layout, self.layout_isize]), let (ret_layout, arg_layouts): (&'a Layout<'a>, &'a [Layout<'a>]) = {
HelperOp::Dec => self.arena.alloc([*sub_layout]), match ctx.op {
HelperOp::DecRef => unreachable!("DecRef is not recursive"), Dec | DecRef => (&LAYOUT_UNIT, self.arena.alloc([layout])),
}; Inc => (&LAYOUT_UNIT, self.arena.alloc([layout, self.layout_isize])),
let ret_layout = if matches!(op, HelperOp::Eq) { Eq => (&LAYOUT_BOOL, self.arena.alloc([layout, layout])),
&LAYOUT_BOOL }
} else {
&LAYOUT_UNIT
}; };
Expr::Call(Call { Expr::Call(Call {
call_type: CallType::ByName { call_type: CallType::ByName {
name: *proc_name, name: proc_name,
ret_layout, ret_layout,
arg_layouts, arg_layouts,
specialization_id: CallSpecId::BACKEND_DUMMY, specialization_id: CallSpecId::BACKEND_DUMMY,
@ -439,20 +267,9 @@ impl<'a> CodeGenHelp<'a> {
arguments: self.arena.alloc_slice_copy(arguments), arguments: self.arena.alloc_slice_copy(arguments),
}) })
} else { } else {
// By the time we get here (generating helper procs), the list of specializations is complete.
// So if we didn't find one, we must be at a leaf of the layout tree.
debug_assert!(!layout_needs_helper_proc(sub_layout, op));
let lowlevel = match op {
HelperOp::Eq => LowLevel::Eq,
HelperOp::Inc => LowLevel::RefCountInc,
HelperOp::Dec => LowLevel::RefCountDec,
HelperOp::DecRef => unreachable!("DecRef is not recursive"),
};
Expr::Call(Call { Expr::Call(Call {
call_type: CallType::LowLevel { call_type: CallType::LowLevel {
op: lowlevel, op: LowLevel::Eq,
update_mode: UpdateModeId::BACKEND_DUMMY, update_mode: UpdateModeId::BACKEND_DUMMY,
}, },
arguments: self.arena.alloc_slice_copy(arguments), arguments: self.arena.alloc_slice_copy(arguments),
@ -460,6 +277,105 @@ impl<'a> CodeGenHelp<'a> {
} }
} }
fn find_or_create_proc(
&mut self,
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
layout: Layout<'a>,
) -> Symbol {
use HelperOp::*;
let mut new_procs_info = Vec::new_in(self.arena);
let found = self
.specialized_procs
.iter()
.find(|spec| spec.op == ctx.op && spec.layout == layout);
if let Some(spec) = found {
return spec.proc.name;
}
let (proc_symbol, proc_layout) = self.create_proc_symbol(ident_ids, ctx, &layout);
new_procs_info.push((proc_symbol, proc_layout));
// Generate the body of the Proc
let (ret_layout, body) = match ctx.op {
Inc | Dec | DecRef => (LAYOUT_UNIT, self.refcount_generic(ident_ids, ctx, layout)),
Eq => (LAYOUT_BOOL, self.eq_generic(ident_ids, ctx, layout)),
};
let args: &'a [(Layout<'a>, Symbol)] = {
let roc_value = (layout, ARG_1);
match ctx.op {
Inc => {
let inc_amount = (self.layout_isize, ARG_2);
self.arena.alloc([roc_value, inc_amount])
}
Dec | DecRef => self.arena.alloc([roc_value]),
Eq => self.arena.alloc([roc_value, (layout, ARG_2)]),
}
};
let proc = Proc {
name: proc_symbol,
args,
body,
closure_data_layout: None,
ret_layout,
is_self_recursive: SelfRecursive::NotSelfRecursive,
must_own_arguments: false,
host_exposed_layouts: HostExposedLayouts::NotHostExposed,
};
self.specialized_procs.push(SpecializedProc {
op: ctx.op,
layout,
proc,
});
proc_symbol
}
fn create_proc_symbol(
&self,
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
layout: &Layout<'a>,
) -> (Symbol, Option<ProcLayout<'a>>) {
let layout_name = layout_debug_name(layout);
let debug_name = format!(
"#help{:?}_{}_{}",
ctx.op,
layout_name,
self.specialized_procs.len()
);
let proc_symbol: Symbol = self.create_symbol(ident_ids, &debug_name);
let proc_layout = match ctx.op {
HelperOp::Inc => Some(ProcLayout {
arguments: self.arena.alloc([*layout, self.layout_isize]),
result: LAYOUT_UNIT,
}),
HelperOp::Dec => Some(ProcLayout {
arguments: self.arena.alloc([*layout]),
result: LAYOUT_UNIT,
}),
HelperOp::DecRef => None,
HelperOp::Eq => Some(ProcLayout {
arguments: self.arena.alloc([*layout, *layout]),
result: LAYOUT_BOOL,
}),
};
(proc_symbol, proc_layout)
}
fn create_symbol(&self, ident_ids: &mut IdentIds, debug_name: &str) -> Symbol {
let ident_id = ident_ids.add(Ident::from(debug_name));
Symbol::new(self.home, ident_id)
}
// ============================================================================ // ============================================================================
// //
// GENERATE REFCOUNTING // GENERATE REFCOUNTING
@ -469,8 +385,8 @@ impl<'a> CodeGenHelp<'a> {
fn refcount_generic( fn refcount_generic(
&self, &self,
ident_ids: &mut IdentIds, ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
layout: Layout<'a>, layout: Layout<'a>,
op: HelperOp,
) -> Stmt<'a> { ) -> Stmt<'a> {
debug_assert!(Self::is_rc_implemented_yet(&layout)); debug_assert!(Self::is_rc_implemented_yet(&layout));
let rc_todo = || todo!("Please update is_rc_implemented_yet for `{:?}`", layout); let rc_todo = || todo!("Please update is_rc_implemented_yet for `{:?}`", layout);
@ -479,7 +395,7 @@ impl<'a> CodeGenHelp<'a> {
Layout::Builtin( Layout::Builtin(
Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal, Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal,
) => unreachable!("Not refcounted: {:?}", layout), ) => unreachable!("Not refcounted: {:?}", layout),
Layout::Builtin(Builtin::Str) => self.refcount_str(ident_ids, op), Layout::Builtin(Builtin::Str) => self.refcount_str(ident_ids, ctx),
Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => rc_todo(), Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_) | Builtin::List(_)) => rc_todo(),
Layout::Struct(_) => rc_todo(), Layout::Struct(_) => rc_todo(),
Layout::Union(_) => rc_todo(), Layout::Union(_) => rc_todo(),
@ -561,7 +477,9 @@ impl<'a> CodeGenHelp<'a> {
} }
/// Generate a procedure to modify the reference count of a Str /// Generate a procedure to modify the reference count of a Str
fn refcount_str(&self, ident_ids: &mut IdentIds, op: HelperOp) -> Stmt<'a> { fn refcount_str(&self, ident_ids: &mut IdentIds, ctx: &mut Context<'a>) -> Stmt<'a> {
let op = ctx.op;
let string = ARG_1; let string = ARG_1;
let layout_isize = self.layout_isize; let layout_isize = self.layout_isize;
@ -678,7 +596,12 @@ impl<'a> CodeGenHelp<'a> {
// //
// ============================================================================ // ============================================================================
fn eq_generic(&self, ident_ids: &mut IdentIds, layout: Layout<'a>) -> Stmt<'a> { fn eq_generic(
&mut self,
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
layout: Layout<'a>,
) -> Stmt<'a> {
let eq_todo = || todo!("Specialized `==` operator for `{:?}`", layout); let eq_todo = || todo!("Specialized `==` operator for `{:?}`", layout);
let main_body = match layout { let main_body = match layout {
@ -692,9 +615,11 @@ impl<'a> CodeGenHelp<'a> {
unreachable!("No generated helper proc for `==` on Str. Use Zig function.") unreachable!("No generated helper proc for `==` on Str. Use Zig function.")
} }
Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_)) => eq_todo(), Layout::Builtin(Builtin::Dict(_, _) | Builtin::Set(_)) => eq_todo(),
Layout::Builtin(Builtin::List(elem_layout)) => self.eq_list(ident_ids, elem_layout), Layout::Builtin(Builtin::List(elem_layout)) => {
Layout::Struct(field_layouts) => self.eq_struct(ident_ids, field_layouts), self.eq_list(ident_ids, ctx, elem_layout)
Layout::Union(union_layout) => self.eq_tag_union(ident_ids, union_layout), }
Layout::Struct(field_layouts) => self.eq_struct(ident_ids, ctx, field_layouts),
Layout::Union(union_layout) => self.eq_tag_union(ident_ids, ctx, union_layout),
Layout::LambdaSet(_) => unreachable!("`==` is not defined on functions"), Layout::LambdaSet(_) => unreachable!("`==` is not defined on functions"),
Layout::RecursivePointer => eq_todo(), Layout::RecursivePointer => eq_todo(),
}; };
@ -779,17 +704,22 @@ impl<'a> CodeGenHelp<'a> {
} }
} }
fn eq_struct(&self, ident_ids: &mut IdentIds, field_layouts: &'a [Layout<'a>]) -> Stmt<'a> { fn eq_struct(
let else_clause = self.eq_fields(ident_ids, 0, field_layouts, None); &mut self,
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
field_layouts: &'a [Layout<'a>],
) -> Stmt<'a> {
let else_clause = self.eq_fields(ident_ids, ctx, 0, field_layouts);
self.if_pointers_equal_return_true(ident_ids, self.arena.alloc(else_clause)) self.if_pointers_equal_return_true(ident_ids, self.arena.alloc(else_clause))
} }
fn eq_fields( fn eq_fields(
&self, &mut self,
ident_ids: &mut IdentIds, ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
tag_id: u64, tag_id: u64,
field_layouts: &'a [Layout<'a>], field_layouts: &'a [Layout<'a>],
rec_ptr_layout: Option<Layout<'a>>,
) -> Stmt<'a> { ) -> Stmt<'a> {
let mut stmt = Stmt::Ret(Symbol::BOOL_TRUE); let mut stmt = Stmt::Ret(Symbol::BOOL_TRUE);
for (i, layout) in field_layouts.iter().enumerate().rev() { for (i, layout) in field_layouts.iter().enumerate().rev() {
@ -809,13 +739,13 @@ impl<'a> CodeGenHelp<'a> {
}; };
let field2_stmt = |next| Stmt::Let(field2_sym, field2_expr, *layout, next); let field2_stmt = |next| Stmt::Let(field2_sym, field2_expr, *layout, next);
let sub_layout = match (layout, rec_ptr_layout) { let eq_call_expr = self.call_specialized_op(
(Layout::RecursivePointer, Some(rec_layout)) => self.arena.alloc(rec_layout), ident_ids,
_ => layout, ctx,
}; *layout,
self.arena.alloc([field1_sym, field2_sym]),
);
let eq_call_expr =
self.apply_op_to_sub_layout(HelperOp::Eq, sub_layout, &[field1_sym, field2_sym]);
let eq_call_name = format!("eq_call_{}", i); let eq_call_name = format!("eq_call_{}", i);
let eq_call_sym = self.create_symbol(ident_ids, &eq_call_name); let eq_call_sym = self.create_symbol(ident_ids, &eq_call_name);
let eq_call_stmt = |next| Stmt::Let(eq_call_sym, eq_call_expr, LAYOUT_BOOL, next); let eq_call_stmt = |next| Stmt::Let(eq_call_sym, eq_call_expr, LAYOUT_BOOL, next);
@ -834,43 +764,54 @@ impl<'a> CodeGenHelp<'a> {
stmt stmt
} }
fn eq_tag_union(&self, ident_ids: &mut IdentIds, union_layout: UnionLayout<'a>) -> Stmt<'a> { fn eq_tag_union(
&mut self,
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
union_layout: UnionLayout<'a>,
) -> Stmt<'a> {
use UnionLayout::*; use UnionLayout::*;
let parent_rec_ptr_layout = ctx.rec_ptr_layout;
if !matches!(union_layout, NonRecursive(_)) {
ctx.rec_ptr_layout = Some(union_layout);
}
let main_stmt = match union_layout { let main_stmt = match union_layout {
NonRecursive(tags) => self.eq_tag_union_help(ident_ids, union_layout, tags, None), NonRecursive(tags) => self.eq_tag_union_help(ident_ids, ctx, union_layout, tags, None),
Recursive(tags) => self.eq_tag_union_help(ident_ids, union_layout, tags, None), Recursive(tags) => self.eq_tag_union_help(ident_ids, ctx, union_layout, tags, None),
NonNullableUnwrapped(field_layouts) => self.eq_fields( NonNullableUnwrapped(field_layouts) => self.eq_fields(ident_ids, ctx, 0, field_layouts),
ident_ids,
0,
field_layouts,
Some(Layout::Union(union_layout)),
),
NullableWrapped { NullableWrapped {
other_tags, other_tags,
nullable_id, nullable_id,
} => self.eq_tag_union_help(ident_ids, union_layout, other_tags, Some(nullable_id)), } => {
self.eq_tag_union_help(ident_ids, ctx, union_layout, other_tags, Some(nullable_id))
}
NullableUnwrapped { NullableUnwrapped {
other_fields, other_fields,
nullable_id: n, nullable_id: n,
} => self.eq_tag_union_help( } => self.eq_tag_union_help(
ident_ids, ident_ids,
ctx,
union_layout, union_layout,
self.arena.alloc([other_fields]), self.arena.alloc([other_fields]),
Some(n as u16), Some(n as u16),
), ),
}; };
ctx.rec_ptr_layout = parent_rec_ptr_layout;
self.if_pointers_equal_return_true(ident_ids, self.arena.alloc(main_stmt)) self.if_pointers_equal_return_true(ident_ids, self.arena.alloc(main_stmt))
} }
fn eq_tag_union_help( fn eq_tag_union_help(
&self, &mut self,
ident_ids: &mut IdentIds, ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
union_layout: UnionLayout<'a>, union_layout: UnionLayout<'a>,
tag_layouts: &'a [&'a [Layout<'a>]], tag_layouts: &'a [&'a [Layout<'a>]],
nullable_id: Option<u16>, nullable_id: Option<u16>,
@ -904,30 +845,18 @@ impl<'a> CodeGenHelp<'a> {
}; };
let tag_ids_eq = self.create_symbol(ident_ids, "tag_ids_eq"); let tag_ids_eq = self.create_symbol(ident_ids, "tag_ids_eq");
let tag_ids_eq_stmt = |next| { let tag_ids_expr = Expr::Call(Call {
Stmt::Let( call_type: CallType::LowLevel {
tag_ids_eq, op: LowLevel::Eq,
Expr::Call(Call { update_mode: UpdateModeId::BACKEND_DUMMY,
call_type: CallType::LowLevel { },
op: LowLevel::Eq, arguments: self.arena.alloc([tag_id_a, tag_id_b]),
update_mode: UpdateModeId::BACKEND_DUMMY, });
}, let tag_ids_eq_stmt = |next| Stmt::Let(tag_ids_eq, tag_ids_expr, LAYOUT_BOOL, next);
arguments: self.arena.alloc([tag_id_a, tag_id_b]),
}),
LAYOUT_BOOL,
next,
)
};
let if_equal_ids_stmt = |next| Stmt::Switch { let if_equal_ids_branches =
cond_symbol: tag_ids_eq, self.arena
cond_layout: LAYOUT_BOOL, .alloc([(0, BranchInfo::None, Stmt::Ret(Symbol::BOOL_FALSE))]);
branches: self
.arena
.alloc([(0, BranchInfo::None, Stmt::Ret(Symbol::BOOL_FALSE))]),
default_branch: (BranchInfo::None, next),
ret_layout: LAYOUT_BOOL,
};
// //
// Switch statement by tag ID // Switch statement by tag ID
@ -940,8 +869,6 @@ impl<'a> CodeGenHelp<'a> {
tag_branches.push((id as u64, BranchInfo::None, Stmt::Ret(Symbol::BOOL_TRUE))) tag_branches.push((id as u64, BranchInfo::None, Stmt::Ret(Symbol::BOOL_TRUE)))
} }
let recursive_ptr_layout = Some(Layout::Union(union_layout));
let mut tag_id: u64 = 0; let mut tag_id: u64 = 0;
for field_layouts in tag_layouts.iter().take(tag_layouts.len() - 1) { for field_layouts in tag_layouts.iter().take(tag_layouts.len() - 1) {
if let Some(null_id) = nullable_id { if let Some(null_id) = nullable_id {
@ -950,11 +877,8 @@ impl<'a> CodeGenHelp<'a> {
} }
} }
tag_branches.push(( let tag_stmt = self.eq_fields(ident_ids, ctx, tag_id, field_layouts);
tag_id, tag_branches.push((tag_id, BranchInfo::None, tag_stmt));
BranchInfo::None,
self.eq_fields(ident_ids, tag_id, field_layouts, recursive_ptr_layout),
));
tag_id += 1; tag_id += 1;
} }
@ -967,14 +891,22 @@ impl<'a> CodeGenHelp<'a> {
BranchInfo::None, BranchInfo::None,
self.arena.alloc(self.eq_fields( self.arena.alloc(self.eq_fields(
ident_ids, ident_ids,
ctx,
tag_id, tag_id,
tag_layouts.last().unwrap(), tag_layouts.last().unwrap(),
recursive_ptr_layout,
)), )),
), ),
ret_layout: LAYOUT_BOOL, ret_layout: LAYOUT_BOOL,
}; };
let if_equal_ids_stmt = Stmt::Switch {
cond_symbol: tag_ids_eq,
cond_layout: LAYOUT_BOOL,
branches: if_equal_ids_branches,
default_branch: (BranchInfo::None, self.arena.alloc(tag_switch_stmt)),
ret_layout: LAYOUT_BOOL,
};
// //
// combine all the statments // combine all the statments
// //
@ -984,10 +916,7 @@ impl<'a> CodeGenHelp<'a> {
// //
tag_ids_eq_stmt(self.arena.alloc( tag_ids_eq_stmt(self.arena.alloc(
// //
if_equal_ids_stmt(self.arena.alloc( if_equal_ids_stmt,
//
tag_switch_stmt,
)),
)), )),
)), )),
)) ))
@ -999,9 +928,15 @@ impl<'a> CodeGenHelp<'a> {
/// To achieve this we use `PtrCast` to cast the element pointer to a "Box" layout. /// To achieve this we use `PtrCast` to cast the element pointer to a "Box" layout.
/// Then we can increment the Box pointer in a loop, dereferencing it each time. /// Then we can increment the Box pointer in a loop, dereferencing it each time.
/// (An alternative approach would be to create a new lowlevel like ListPeekUnsafe.) /// (An alternative approach would be to create a new lowlevel like ListPeekUnsafe.)
fn eq_list(&self, ident_ids: &mut IdentIds, elem_layout: &Layout<'a>) -> Stmt<'a> { fn eq_list(
&mut self,
ident_ids: &mut IdentIds,
ctx: &mut Context<'a>,
elem_layout: &Layout<'a>,
) -> Stmt<'a> {
use LowLevel::*; use LowLevel::*;
let layout_isize = self.layout_isize; let layout_isize = self.layout_isize;
let arena = self.arena;
// A "Box" layout (heap pointer to a single list element) // A "Box" layout (heap pointer to a single list element)
let box_union_layout = UnionLayout::NonNullableUnwrapped(self.arena.alloc([*elem_layout])); let box_union_layout = UnionLayout::NonNullableUnwrapped(self.arena.alloc([*elem_layout]));
@ -1011,11 +946,12 @@ impl<'a> CodeGenHelp<'a> {
let len_1 = self.create_symbol(ident_ids, "len_1"); let len_1 = self.create_symbol(ident_ids, "len_1");
let len_2 = self.create_symbol(ident_ids, "len_2"); let len_2 = self.create_symbol(ident_ids, "len_2");
let len_1_stmt = |next| self.let_lowlevel(layout_isize, len_1, ListLen, &[ARG_1], next); let len_1_stmt = |next| let_lowlevel(arena, layout_isize, len_1, ListLen, &[ARG_1], next);
let len_2_stmt = |next| self.let_lowlevel(layout_isize, len_2, ListLen, &[ARG_2], next); let len_2_stmt = |next| let_lowlevel(arena, layout_isize, len_2, ListLen, &[ARG_2], next);
let eq_len = self.create_symbol(ident_ids, "eq_len"); let eq_len = self.create_symbol(ident_ids, "eq_len");
let eq_len_stmt = |next| self.let_lowlevel(LAYOUT_BOOL, eq_len, Eq, &[len_1, len_2], next); let eq_len_stmt =
|next| let_lowlevel(arena, LAYOUT_BOOL, eq_len, Eq, &[len_1, len_2], next);
// if lengths are equal... // if lengths are equal...
@ -1038,10 +974,26 @@ impl<'a> CodeGenHelp<'a> {
// Cast to integers // Cast to integers
let start_addr_1 = self.create_symbol(ident_ids, "start_addr_1"); let start_addr_1 = self.create_symbol(ident_ids, "start_addr_1");
let start_addr_2 = self.create_symbol(ident_ids, "start_addr_2"); let start_addr_2 = self.create_symbol(ident_ids, "start_addr_2");
let start_addr_1_stmt = let start_addr_1_stmt = |next| {
|next| self.let_lowlevel(layout_isize, start_addr_1, PtrCast, &[elements_1], next); let_lowlevel(
let start_addr_2_stmt = arena,
|next| self.let_lowlevel(layout_isize, start_addr_2, PtrCast, &[elements_2], next); layout_isize,
start_addr_1,
PtrCast,
&[elements_1],
next,
)
};
let start_addr_2_stmt = |next| {
let_lowlevel(
arena,
layout_isize,
start_addr_2,
PtrCast,
&[elements_2],
next,
)
};
// //
// Loop initialisation // Loop initialisation
@ -1055,13 +1007,22 @@ impl<'a> CodeGenHelp<'a> {
// let list_size = len_1 * elem_size // let list_size = len_1 * elem_size
let list_size = self.create_symbol(ident_ids, "list_size"); let list_size = self.create_symbol(ident_ids, "list_size");
let list_size_stmt = let list_size_stmt = |next| {
|next| self.let_lowlevel(layout_isize, list_size, NumMul, &[len_1, elem_size], next); let_lowlevel(
arena,
layout_isize,
list_size,
NumMul,
&[len_1, elem_size],
next,
)
};
// let end_addr_1 = start_addr_1 + list_size // let end_addr_1 = start_addr_1 + list_size
let end_addr_1 = self.create_symbol(ident_ids, "end_addr_1"); let end_addr_1 = self.create_symbol(ident_ids, "end_addr_1");
let end_addr_1_stmt = |next| { let end_addr_1_stmt = |next| {
self.let_lowlevel( let_lowlevel(
arena,
layout_isize, layout_isize,
end_addr_1, end_addr_1,
NumAdd, NumAdd,
@ -1097,8 +1058,8 @@ impl<'a> CodeGenHelp<'a> {
// Cast integers to box pointers // Cast integers to box pointers
let box1 = self.create_symbol(ident_ids, "box1"); let box1 = self.create_symbol(ident_ids, "box1");
let box2 = self.create_symbol(ident_ids, "box2"); let box2 = self.create_symbol(ident_ids, "box2");
let box1_stmt = |next| self.let_lowlevel(box_layout, box1, PtrCast, &[addr1], next); let box1_stmt = |next| let_lowlevel(arena, box_layout, box1, PtrCast, &[addr1], next);
let box2_stmt = |next| self.let_lowlevel(box_layout, box2, PtrCast, &[addr2], next); let box2_stmt = |next| let_lowlevel(arena, box_layout, box2, PtrCast, &[addr2], next);
// Dereference the box pointers to get the current elements // Dereference the box pointers to get the current elements
let elem1 = self.create_symbol(ident_ids, "elem1"); let elem1 = self.create_symbol(ident_ids, "elem1");
@ -1120,16 +1081,33 @@ impl<'a> CodeGenHelp<'a> {
// Compare the two current elements // Compare the two current elements
let eq_elems = self.create_symbol(ident_ids, "eq_elems"); let eq_elems = self.create_symbol(ident_ids, "eq_elems");
let eq_elems_expr = self.apply_op_to_sub_layout(HelperOp::Eq, elem_layout, &[elem1, elem2]); let eq_elems_expr = self.call_specialized_op(ident_ids, ctx, *elem_layout, &[elem1, elem2]);
let eq_elems_stmt = |next| Stmt::Let(eq_elems, eq_elems_expr, LAYOUT_BOOL, next); let eq_elems_stmt = |next| Stmt::Let(eq_elems, eq_elems_expr, LAYOUT_BOOL, next);
// If current elements are equal, loop back again // If current elements are equal, loop back again
let next_addr_1 = self.create_symbol(ident_ids, "next_addr_1"); let next_addr_1 = self.create_symbol(ident_ids, "next_addr_1");
let next_addr_2 = self.create_symbol(ident_ids, "next_addr_2"); let next_addr_2 = self.create_symbol(ident_ids, "next_addr_2");
let next_addr_1_stmt = let next_addr_1_stmt = |next| {
|next| self.let_lowlevel(layout_isize, next_addr_1, NumAdd, &[addr1, elem_size], next); let_lowlevel(
let next_addr_2_stmt = arena,
|next| self.let_lowlevel(layout_isize, next_addr_2, NumAdd, &[addr2, elem_size], next); layout_isize,
next_addr_1,
NumAdd,
&[addr1, elem_size],
next,
)
};
let next_addr_2_stmt = |next| {
let_lowlevel(
arena,
layout_isize,
next_addr_2,
NumAdd,
&[addr2, elem_size],
next,
)
};
let jump_back = Stmt::Jump(elems_loop, self.arena.alloc([next_addr_1, next_addr_2])); let jump_back = Stmt::Jump(elems_loop, self.arena.alloc([next_addr_1, next_addr_2]));
@ -1138,8 +1116,16 @@ impl<'a> CodeGenHelp<'a> {
// //
let is_end = self.create_symbol(ident_ids, "is_end"); let is_end = self.create_symbol(ident_ids, "is_end");
let is_end_stmt = let is_end_stmt = |next| {
|next| self.let_lowlevel(LAYOUT_BOOL, is_end, NumGte, &[addr1, end_addr_1], next); let_lowlevel(
arena,
LAYOUT_BOOL,
is_end,
NumGte,
&[addr1, end_addr_1],
next,
)
};
let if_elems_not_equal = self.if_false_return_false( let if_elems_not_equal = self.if_false_return_false(
eq_elems, eq_elems,
@ -1246,28 +1232,28 @@ impl<'a> CodeGenHelp<'a> {
self.if_pointers_equal_return_true(ident_ids, self.arena.alloc(pointers_else)) self.if_pointers_equal_return_true(ident_ids, self.arena.alloc(pointers_else))
} }
}
fn let_lowlevel( fn let_lowlevel<'a>(
&self, arena: &'a Bump,
result_layout: Layout<'a>, result_layout: Layout<'a>,
result: Symbol, result: Symbol,
op: LowLevel, op: LowLevel,
args: &[Symbol], arguments: &[Symbol],
next: &'a Stmt<'a>, next: &'a Stmt<'a>,
) -> Stmt<'a> { ) -> Stmt<'a> {
Stmt::Let( Stmt::Let(
result, result,
Expr::Call(Call { Expr::Call(Call {
call_type: CallType::LowLevel { call_type: CallType::LowLevel {
op, op,
update_mode: UpdateModeId::BACKEND_DUMMY, update_mode: UpdateModeId::BACKEND_DUMMY,
}, },
arguments: self.arena.alloc_slice_copy(args), arguments: arena.alloc_slice_copy(arguments),
}), }),
result_layout, result_layout,
next, next,
) )
}
} }
/// Helper to derive a debug function name from a layout /// Helper to derive a debug function name from a layout