Add join points and tail call optimization to the dev backend.

Brendan Hansknecht 2021-09-20 23:13:30 -07:00
parent 946502b8ea
commit 16d098da5e
6 changed files with 375 additions and 124 deletions


@@ -2,10 +2,9 @@ use crate::{Backend, Env, Relocation};
use bumpalo::collections::Vec;
use roc_collections::all::{MutMap, MutSet};
use roc_module::symbol::Symbol;
use roc_mono::ir::{BranchInfo, Literal, Stmt};
use roc_mono::ir::{BranchInfo, JoinPointId, Literal, Param, SelfRecursive, Stmt};
use roc_mono::layout::{Builtin, Layout};
use std::marker::PhantomData;
use target_lexicon::Triple;
pub mod aarch64;
pub mod x86_64;
@@ -211,12 +210,16 @@ pub struct Backend64Bit<
env: &'a Env<'a>,
buf: Vec<'a, u8>,
relocs: Vec<'a, Relocation>,
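// Name and self-recursion status of the proc currently being built; used to detect self tail calls.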
proc_name: Option<String>,
is_self_recursive: Option<SelfRecursive>,
last_seen_map: MutMap<Symbol, *const Stmt<'a>>,
layout_map: MutMap<Symbol, *const Layout<'a>>,
free_map: MutMap<*const Stmt<'a>, Vec<'a, Symbol>>,
symbol_storage_map: MutMap<Symbol, SymbolStorage<GeneralReg, FloatReg>>,
literal_map: MutMap<Symbol, Literal<'a>>,
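// Maps each join point to the offset where its body starts in the function buffer.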
join_map: MutMap<JoinPointId, u64>,
// This should probably be smarter than a vec.
// There are certain registers we should always use first. With pushing and popping, that order could get mixed up.
@@ -247,11 +250,13 @@ impl<
CC: CallConv<GeneralReg, FloatReg>,
> Backend<'a> for Backend64Bit<'a, GeneralReg, FloatReg, ASM, CC>
{
fn new(env: &'a Env, _target: &Triple) -> Result<Self, String> {
fn new(env: &'a Env) -> Result<Self, String> {
Ok(Backend64Bit {
phantom_asm: PhantomData,
phantom_cc: PhantomData,
env,
proc_name: None,
is_self_recursive: None,
buf: bumpalo::vec![in env.arena],
relocs: bumpalo::vec![in env.arena],
last_seen_map: MutMap::default(),
@@ -259,6 +264,7 @@ impl<
free_map: MutMap::default(),
symbol_storage_map: MutMap::default(),
literal_map: MutMap::default(),
join_map: MutMap::default(),
general_free_regs: bumpalo::vec![in env.arena],
general_used_regs: bumpalo::vec![in env.arena],
general_used_callee_saved_regs: MutSet::default(),
@@ -275,12 +281,15 @@ impl<
self.env
}
fn reset(&mut self) {
fn reset(&mut self, name: String, is_self_recursive: SelfRecursive) {
self.proc_name = Some(name);
self.is_self_recursive = Some(is_self_recursive);
self.stack_size = 0;
self.free_stack_chunks.clear();
self.fn_call_stack_size = 0;
self.last_seen_map.clear();
self.layout_map.clear();
self.join_map.clear();
self.free_map.clear();
self.symbol_storage_map.clear();
self.buf.clear();
@@ -330,6 +339,19 @@ impl<
)?;
let setup_offset = out.len();
// Deal with jumps to the return address.
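// Any JmpToReturn relocation recorded while building the body gets patched here to jump to the shared epilogue that follows the body.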
let ret_offset = self.buf.len();
let old_relocs = std::mem::replace(&mut self.relocs, bumpalo::vec![in self.env.arena]);
let mut tmp = bumpalo::vec![in self.env.arena];
for reloc in old_relocs
.iter()
.filter(|reloc| matches!(reloc, Relocation::JmpToReturn { .. }))
{
if let Relocation::JmpToReturn { inst_loc, offset } = reloc {
self.update_jmp_imm32_offset(&mut tmp, *inst_loc, *offset, ret_offset as u64);
}
}
// Add function body.
out.extend(&self.buf);
@@ -342,23 +364,28 @@ impl<
)?;
ASM::ret(&mut out);
// Update relocs to include stack setup offset.
// Update other relocs to include stack setup offset.
let mut out_relocs = bumpalo::vec![in self.env.arena];
let old_relocs = std::mem::replace(&mut self.relocs, bumpalo::vec![in self.env.arena]);
out_relocs.extend(old_relocs.into_iter().map(|reloc| match reloc {
Relocation::LocalData { offset, data } => Relocation::LocalData {
offset: offset + setup_offset as u64,
data,
},
Relocation::LinkedData { offset, name } => Relocation::LinkedData {
offset: offset + setup_offset as u64,
name,
},
Relocation::LinkedFunction { offset, name } => Relocation::LinkedFunction {
offset: offset + setup_offset as u64,
name,
},
}));
out_relocs.extend(
old_relocs
.into_iter()
.filter(|reloc| !matches!(reloc, Relocation::JmpToReturn { .. }))
.map(|reloc| match reloc {
Relocation::LocalData { offset, data } => Relocation::LocalData {
offset: offset + setup_offset as u64,
data,
},
Relocation::LinkedData { offset, name } => Relocation::LinkedData {
offset: offset + setup_offset as u64,
name,
},
Relocation::LinkedFunction { offset, name } => Relocation::LinkedFunction {
offset: offset + setup_offset as u64,
name,
},
Relocation::JmpToReturn { .. } => unreachable!(),
}),
);
Ok((out.into_bump_slice(), out_relocs.into_bump_slice()))
}
@@ -401,29 +428,13 @@ impl<
arg_layouts: &[Layout<'a>],
ret_layout: &Layout<'a>,
) -> Result<(), String> {
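// Tail call optimization: a self-recursive call to the current proc is lowered to a jump back to the enclosing join point instead of an actual call.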
if let Some(SelfRecursive::SelfRecursive(id)) = self.is_self_recursive {
if &fn_name == self.proc_name.as_ref().unwrap() && self.join_map.contains_key(&id) {
return self.build_jump(&id, args, arg_layouts, ret_layout);
}
}
// Save used caller saved regs.
let old_general_used_regs = std::mem::replace(
&mut self.general_used_regs,
bumpalo::vec![in self.env.arena],
);
for (reg, saved_sym) in old_general_used_regs.into_iter() {
if CC::general_caller_saved(&reg) {
self.general_free_regs.push(reg);
self.free_to_stack(&saved_sym)?;
} else {
self.general_used_regs.push((reg, saved_sym));
}
}
let old_float_used_regs =
std::mem::replace(&mut self.float_used_regs, bumpalo::vec![in self.env.arena]);
for (reg, saved_sym) in old_float_used_regs.into_iter() {
if CC::float_caller_saved(&reg) {
self.float_free_regs.push(reg);
self.free_to_stack(&saved_sym)?;
} else {
self.float_used_regs.push((reg, saved_sym));
}
}
self.push_used_caller_saved_regs_to_stack()?;
// Put values in param regs or on top of the stack.
let tmp_stack_size = CC::store_args(
@@ -486,7 +497,7 @@ impl<
// Build unconditional jump to the end of this switch.
// Since we don't know the offset yet, emit a placeholder and overwrite it later.
let jmp_location = self.buf.len();
let jmp_offset = ASM::jmp_imm32(&mut self.buf, 0);
let jmp_offset = ASM::jmp_imm32(&mut self.buf, 0x1234_5678);
ret_jumps.push((jmp_location, jmp_offset));
// Overwrite the original jne with the correct offset.
@@ -510,12 +521,12 @@ impl<
// Update all return jumps to jump past the default case.
let ret_offset = self.buf.len();
for (jmp_location, start_offset) in ret_jumps.into_iter() {
tmp.clear();
let jmp_offset = ret_offset - start_offset;
ASM::jmp_imm32(&mut tmp, jmp_offset as i32);
for (i, byte) in tmp.iter().enumerate() {
self.buf[jmp_location + i] = *byte;
}
self.update_jmp_imm32_offset(
&mut tmp,
jmp_location as u64,
start_offset as u64,
ret_offset as u64,
);
}
Ok(())
} else {
@@ -526,6 +537,135 @@ impl<
}
}
fn build_join(
&mut self,
id: &JoinPointId,
parameters: &'a [Param<'a>],
body: &'a Stmt<'a>,
remainder: &'a Stmt<'a>,
ret_layout: &Layout<'a>,
) -> Result<(), String> {
for param in parameters {
if param.borrow {
return Err("Join: borrowed parameters not yet supported".to_string());
}
}
// Create a placeholder jump over the join body to the remainder; its offset is patched once the body length is known.
let jmp_location = self.buf.len();
let start_offset = ASM::jmp_imm32(&mut self.buf, 0x1234_5678);
// This section can essentially be seen as a sub function within the main function.
// Thus we build using a new backend with some minor extra synchronization.
let mut sub_backend = Self::new(self.env)?;
sub_backend.reset(
self.proc_name.as_ref().unwrap().clone(),
self.is_self_recursive.as_ref().unwrap().clone(),
);
// Sync static maps of important information.
sub_backend.last_seen_map = self.last_seen_map.clone();
sub_backend.layout_map = self.layout_map.clone();
sub_backend.free_map = self.free_map.clone();
// Setup join point.
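// The join body starts at offset 0 inside the sub backend's buffer, and right after the placeholder jump in the parent buffer; recording both offsets lets self-recursive jumps from the body and later jumps from the remainder resolve correctly.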
sub_backend.join_map.insert(*id, 0);
self.join_map.insert(*id, self.buf.len() as u64);
// Sync stack size so the "sub function" doesn't mess up our stack.
sub_backend.stack_size = self.stack_size;
sub_backend.fn_call_stack_size = self.fn_call_stack_size;
// Load params as if they were args.
let mut args = bumpalo::vec![in self.env.arena];
for param in parameters {
args.push((param.layout, param.symbol));
}
sub_backend.load_args(args.into_bump_slice(), ret_layout)?;
// Build all statements in body.
sub_backend.build_stmt(body, ret_layout)?;
// Merge the "sub function" into the main function.
let sub_func_offset = self.buf.len() as u64;
self.buf.extend_from_slice(&sub_backend.buf);
// Update stack based on how much was used by the sub function.
self.stack_size = sub_backend.stack_size;
self.fn_call_stack_size = sub_backend.fn_call_stack_size;
// Relocations must be shifted to be merged correctly.
self.relocs
.extend(sub_backend.relocs.into_iter().map(|reloc| match reloc {
Relocation::LocalData { offset, data } => Relocation::LocalData {
offset: offset + sub_func_offset,
data,
},
Relocation::LinkedData { offset, name } => Relocation::LinkedData {
offset: offset + sub_func_offset,
name,
},
Relocation::LinkedFunction { offset, name } => Relocation::LinkedFunction {
offset: offset + sub_func_offset,
name,
},
Relocation::JmpToReturn { inst_loc, offset } => Relocation::JmpToReturn {
inst_loc: inst_loc + sub_func_offset,
offset: offset + sub_func_offset,
},
}));
// Overwrite the original jump with the correct offset.
let mut tmp = bumpalo::vec![in self.env.arena];
self.update_jmp_imm32_offset(
&mut tmp,
jmp_location as u64,
start_offset as u64,
self.buf.len() as u64,
);
// Build remainder of function.
self.build_stmt(remainder, ret_layout)
}
fn build_jump(
&mut self,
id: &JoinPointId,
args: &'a [Symbol],
arg_layouts: &[Layout<'a>],
ret_layout: &Layout<'a>,
) -> Result<(), String> {
// Treat this like a function call, but with a jump instead of a call instruction at the end.
self.push_used_caller_saved_regs_to_stack()?;
let tmp_stack_size = CC::store_args(
&mut self.buf,
&self.symbol_storage_map,
args,
arg_layouts,
ret_layout,
)?;
self.fn_call_stack_size = std::cmp::max(self.fn_call_stack_size, tmp_stack_size);
let jmp_location = self.buf.len();
let start_offset = ASM::jmp_imm32(&mut self.buf, 0x1234_5678);
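// Patch the placeholder so it targets the join point's recorded offset (0 within a join body for the self-recursive case).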
if let Some(offset) = self.join_map.get(id) {
let offset = *offset;
let mut tmp = bumpalo::vec![in self.env.arena];
self.update_jmp_imm32_offset(
&mut tmp,
jmp_location as u64,
start_offset as u64,
offset,
);
Ok(())
} else {
Err(format!(
"Jump: unknown point specified to jump to: {:?}",
id
))
}
}
fn build_num_abs(
&mut self,
dst: &Symbol,
@@ -828,29 +968,26 @@ impl<
fn return_symbol(&mut self, sym: &Symbol, layout: &Layout<'a>) -> Result<(), String> {
let val = self.symbol_storage_map.get(sym);
match val {
Some(SymbolStorage::GeneralReg(reg)) if *reg == CC::GENERAL_RETURN_REGS[0] => Ok(()),
Some(SymbolStorage::GeneralReg(reg)) if *reg == CC::GENERAL_RETURN_REGS[0] => {}
Some(SymbolStorage::GeneralReg(reg)) => {
// If it fits in a general purpose register, just copy it over to the return register.
// Technically this can be optimized to produce shorter instructions if the value is smaller than 64 bits.
ASM::mov_reg64_reg64(&mut self.buf, CC::GENERAL_RETURN_REGS[0], *reg);
Ok(())
}
Some(SymbolStorage::FloatReg(reg)) if *reg == CC::FLOAT_RETURN_REGS[0] => Ok(()),
Some(SymbolStorage::FloatReg(reg)) if *reg == CC::FLOAT_RETURN_REGS[0] => {}
Some(SymbolStorage::FloatReg(reg)) => {
ASM::mov_freg64_freg64(&mut self.buf, CC::FLOAT_RETURN_REGS[0], *reg);
Ok(())
}
Some(SymbolStorage::Base { offset, size, .. }) => match layout {
Layout::Builtin(Builtin::Int64) => {
ASM::mov_reg64_base32(&mut self.buf, CC::GENERAL_RETURN_REGS[0], *offset);
Ok(())
}
Layout::Builtin(Builtin::Float64) => {
ASM::mov_freg64_base32(&mut self.buf, CC::FLOAT_RETURN_REGS[0], *offset);
Ok(())
}
Layout::Struct(field_layouts) => {
let (offset, size) = (*offset, *size);
// Nothing to do for empty struct
if size > 0 {
let ret_reg = if self.symbol_storage_map.contains_key(&Symbol::RET_POINTER)
{
@@ -858,23 +995,31 @@ impl<
} else {
None
};
CC::return_struct(&mut self.buf, offset, size, field_layouts, ret_reg)
} else {
// Nothing to do for empty struct
Ok(())
CC::return_struct(&mut self.buf, offset, size, field_layouts, ret_reg)?;
}
}
x => Err(format!(
"returning symbol with layout, {:?}, is not yet implemented",
x
)),
x => {
return Err(format!(
"returning symbol with layout, {:?}, is not yet implemented",
x
));
}
},
Some(x) => Err(format!(
"returning symbol storage, {:?}, is not yet implemented",
x
)),
None => Err(format!("Unknown return symbol: {}", sym)),
Some(x) => {
return Err(format!(
"returning symbol storage, {:?}, is not yet implemented",
x
));
}
None => {
return Err(format!("Unknown return symbol: {}", sym));
}
}
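// Jump to the shared epilogue; the offset is a placeholder that gets patched, via the JmpToReturn relocation, when the function is finalized.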
let inst_loc = self.buf.len() as u64;
let offset = ASM::jmp_imm32(&mut self.buf, 0x1234_5678) as u64;
self.relocs
.push(Relocation::JmpToReturn { inst_loc, offset });
Ok(())
}
}
@@ -1212,4 +1357,45 @@ impl<
)),
}
}
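// Spills every caller-saved register that is currently in use to the stack so an upcoming call or jump can clobber them freely.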
fn push_used_caller_saved_regs_to_stack(&mut self) -> Result<(), String> {
let old_general_used_regs = std::mem::replace(
&mut self.general_used_regs,
bumpalo::vec![in self.env.arena],
);
for (reg, saved_sym) in old_general_used_regs.into_iter() {
if CC::general_caller_saved(&reg) {
self.general_free_regs.push(reg);
self.free_to_stack(&saved_sym)?;
} else {
self.general_used_regs.push((reg, saved_sym));
}
}
let old_float_used_regs =
std::mem::replace(&mut self.float_used_regs, bumpalo::vec![in self.env.arena]);
for (reg, saved_sym) in old_float_used_regs.into_iter() {
if CC::float_caller_saved(&reg) {
self.float_free_regs.push(reg);
self.free_to_stack(&saved_sym)?;
} else {
self.float_used_regs.push((reg, saved_sym));
}
}
Ok(())
}
fn update_jmp_imm32_offset(
&mut self,
tmp: &mut Vec<'a, u8>,
jmp_location: u64,
base_offset: u64,
target_offset: u64,
) {
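// Re-encode the 32-bit jump at `jmp_location` so its immediate targets `target_offset`; `base_offset` is the offset the immediate is relative to, and `tmp` is a scratch buffer for the new encoding.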
tmp.clear();
let jmp_offset = target_offset as i32 - base_offset as i32;
ASM::jmp_imm32(tmp, jmp_offset);
for (i, byte) in tmp.iter().enumerate() {
self.buf[jmp_location as usize + i] = *byte;
}
}
}