ensure that when a switch case uses a callee-saved register, that register gets stored/restored properly

This commit is contained in:
Folkert 2023-11-25 20:18:37 +01:00
parent 85afcdd011
commit 104c44a754
No known key found for this signature in database
GPG key ID: 1F17F6FFD112B97C
2 changed files with 67 additions and 27 deletions

View file

@ -82,8 +82,8 @@ pub trait CallConv<GeneralReg: RegTrait, FloatReg: RegTrait, ASM: Assembler<Gene
fn setup_stack(
buf: &mut Vec<'_, u8>,
general_saved_regs: &[GeneralReg],
float_saved_regs: &[FloatReg],
saved_general_regs: &[GeneralReg],
saved_float_regs: &[FloatReg],
requested_stack_size: i32,
fn_call_stack_size: i32,
) -> i32;
@ -900,8 +900,11 @@ impl<
let mut out = bumpalo::vec![in self.env.arena];
// Setup stack.
let used_general_regs = self.storage_manager.general_used_callee_saved_regs();
let used_float_regs = self.storage_manager.float_used_callee_saved_regs();
let (used_general_regs, used_float_regs) = self
.storage_manager
.used_callee_saved_regs
.as_vecs(self.env.arena);
let aligned_stack_size = CC::setup_stack(
&mut out,
&used_general_regs,
@ -1199,6 +1202,12 @@ impl<
max_branch_stack_size =
std::cmp::max(max_branch_stack_size, self.storage_manager.stack_size());
base_storage.update_fn_call_stack_size(self.storage_manager.fn_call_stack_size());
// make sure that used callee-saved registers get saved/restored even if used in only
// one of the branches of the switch
base_storage
.used_callee_saved_regs
.extend(&self.storage_manager.used_callee_saved_regs);
}
self.storage_manager = base_storage;
self.literal_map = base_literal_map;

View file

@ -3,7 +3,7 @@ use crate::{
pointer_layouts, sign_extended_int_builtins, single_register_floats,
single_register_int_builtins, single_register_integers, single_register_layouts, Env,
};
use bumpalo::collections::Vec;
use bumpalo::collections::{CollectIn, Vec};
use roc_builtins::bitcode::{FloatWidth, IntWidth};
use roc_collections::all::{MutMap, MutSet};
use roc_error_macros::{internal_error, todo_lambda_erasure};
@ -118,10 +118,7 @@ pub struct StorageManager<
general_used_regs: Vec<'a, (GeneralReg, Symbol)>,
float_used_regs: Vec<'a, (FloatReg, Symbol)>,
// TODO: it probably would be faster to make these a list that linearly scans rather than hashing.
// used callee saved regs must be tracked for pushing and popping at the beginning/end of the function.
general_used_callee_saved_regs: MutSet<GeneralReg>,
float_used_callee_saved_regs: MutSet<FloatReg>,
pub(crate) used_callee_saved_regs: UsedCalleeRegisters<GeneralReg, FloatReg>,
free_stack_chunks: Vec<'a, (i32, u32)>,
stack_size: u32,
@ -152,16 +149,62 @@ pub fn new_storage_manager<
join_param_map: MutMap::default(),
general_free_regs: bumpalo::vec![in env.arena],
general_used_regs: bumpalo::vec![in env.arena],
general_used_callee_saved_regs: MutSet::default(),
// must be saved on entering a function, and restored before returning
used_callee_saved_regs: UsedCalleeRegisters::default(),
float_free_regs: bumpalo::vec![in env.arena],
float_used_regs: bumpalo::vec![in env.arena],
float_used_callee_saved_regs: MutSet::default(),
free_stack_chunks: bumpalo::vec![in env.arena],
stack_size: 0,
fn_call_stack_size: 0,
}
}
// optimization idea: use a bitset
#[derive(Debug, Clone)]
pub(crate) struct UsedCalleeRegisters<GeneralReg, FloatReg> {
general: MutSet<GeneralReg>,
float: MutSet<FloatReg>,
}
impl<GeneralReg: RegTrait, FloatReg: RegTrait> UsedCalleeRegisters<GeneralReg, FloatReg> {
fn clear(&mut self) {
self.general.clear();
self.float.clear();
}
fn insert_general(&mut self, reg: GeneralReg) -> bool {
self.general.insert(reg)
}
fn insert_float(&mut self, reg: FloatReg) -> bool {
self.float.insert(reg)
}
pub(crate) fn extend(&mut self, other: &Self) {
self.general.extend(other.general.iter().copied());
self.float.extend(other.float.iter().copied());
}
pub(crate) fn as_vecs<'a>(
&self,
arena: &'a bumpalo::Bump,
) -> (Vec<'a, GeneralReg>, Vec<'a, FloatReg>) {
(
self.general.iter().copied().collect_in(arena),
self.float.iter().copied().collect_in(arena),
)
}
}
impl<GeneralReg, FloatReg> Default for UsedCalleeRegisters<GeneralReg, FloatReg> {
fn default() -> Self {
Self {
general: Default::default(),
float: Default::default(),
}
}
}
impl<
'a,
'r,
@ -175,16 +218,16 @@ impl<
self.symbol_storage_map.clear();
self.allocation_map.clear();
self.join_param_map.clear();
self.general_used_callee_saved_regs.clear();
self.used_callee_saved_regs.clear();
self.general_free_regs.clear();
self.general_used_regs.clear();
self.general_free_regs
.extend_from_slice(CC::GENERAL_DEFAULT_FREE_REGS);
self.float_used_callee_saved_regs.clear();
self.float_free_regs.clear();
self.float_used_regs.clear();
self.float_free_regs
.extend_from_slice(CC::FLOAT_DEFAULT_FREE_REGS);
self.used_callee_saved_regs.clear();
self.free_stack_chunks.clear();
self.stack_size = 0;
self.fn_call_stack_size = 0;
@ -198,18 +241,6 @@ impl<
self.fn_call_stack_size
}
pub fn general_used_callee_saved_regs(&self) -> Vec<'a, GeneralReg> {
let mut used_regs = bumpalo::vec![in self.env.arena];
used_regs.extend(&self.general_used_callee_saved_regs);
used_regs
}
pub fn float_used_callee_saved_regs(&self) -> Vec<'a, FloatReg> {
let mut used_regs = bumpalo::vec![in self.env.arena];
used_regs.extend(&self.float_used_callee_saved_regs);
used_regs
}
/// Returns true if the symbol is storing a primitive value.
pub fn is_stored_primitive(&self, sym: &Symbol) -> bool {
matches!(
@ -223,7 +254,7 @@ impl<
fn get_general_reg(&mut self, buf: &mut Vec<'a, u8>) -> GeneralReg {
if let Some(reg) = self.general_free_regs.pop() {
if CC::general_callee_saved(&reg) {
self.general_used_callee_saved_regs.insert(reg);
self.used_callee_saved_regs.insert_general(reg);
}
reg
} else if !self.general_used_regs.is_empty() {
@ -240,7 +271,7 @@ impl<
fn get_float_reg(&mut self, buf: &mut Vec<'a, u8>) -> FloatReg {
if let Some(reg) = self.float_free_regs.pop() {
if CC::float_callee_saved(&reg) {
self.float_used_callee_saved_regs.insert(reg);
self.used_callee_saved_regs.insert_float(reg);
}
reg
} else if !self.float_used_regs.is_empty() {