Move LLVM scope ops to a separate module

This commit is contained in:
Ayaz Hafiz 2023-06-09 15:33:29 -05:00
parent 470ed119c2
commit 8d3d4ed9d8
No known key found for this signature in database
GPG key ID: 0E2A37416A25EF58
7 changed files with 115 additions and 61 deletions

View file

@ -11,7 +11,6 @@ use crate::llvm::struct_::{self, struct_from_fields, RocStruct};
use bumpalo::collections::Vec;
use bumpalo::Bump;
use inkwell::attributes::{Attribute, AttributeLoc};
use inkwell::basic_block::BasicBlock;
use inkwell::builder::Builder;
use inkwell::context::Context;
use inkwell::debug_info::{
@ -25,8 +24,8 @@ use inkwell::types::{
};
use inkwell::values::BasicValueEnum::{self};
use inkwell::values::{
BasicMetadataValueEnum, CallSiteValue, FunctionValue, InstructionValue, IntValue, PhiValue,
PointerValue, StructValue,
BasicMetadataValueEnum, CallSiteValue, FunctionValue, InstructionValue, IntValue, PointerValue,
StructValue,
};
use inkwell::OptimizationLevel;
use inkwell::{AddressSpace, IntPredicate};
@ -34,13 +33,13 @@ use morphic_lib::{
CalleeSpecVar, FuncName, FuncSpec, FuncSpecSolutions, ModSolutions, UpdateMode, UpdateModeVar,
};
use roc_builtins::bitcode::{self, FloatWidth, IntWidth};
use roc_collections::all::{ImMap, MutMap, MutSet};
use roc_collections::all::{MutMap, MutSet};
use roc_debug_flags::dbg_do;
#[cfg(debug_assertions)]
use roc_debug_flags::ROC_PRINT_LLVM_FN_VERIFICATION;
use roc_module::symbol::{Interns, ModuleId, Symbol};
use roc_module::symbol::{Interns, Symbol};
use roc_mono::ir::{
BranchInfo, CallType, CrashTag, EntryPoint, GlueLayouts, HostExposedLambdaSet, JoinPointId,
BranchInfo, CallType, CrashTag, EntryPoint, GlueLayouts, HostExposedLambdaSet,
ListLiteralElement, ModifyRc, OptLevel, ProcLayout, SingleEntryPoint,
};
use roc_mono::layout::{
@ -59,6 +58,7 @@ use super::intrinsics::{
LLVM_STACK_SAVE,
};
use super::lowlevel::run_higher_order_low_level;
use super::scope::Scope;
pub(crate) trait BuilderExt<'ctx> {
fn new_build_struct_gep(
@ -161,39 +161,6 @@ macro_rules! debug_info_init {
}};
}
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct Scope<'a, 'ctx> {
symbols: ImMap<Symbol, (InLayout<'a>, BasicValueEnum<'ctx>)>,
pub top_level_thunks: ImMap<Symbol, (ProcLayout<'a>, FunctionValue<'ctx>)>,
join_points: ImMap<JoinPointId, (BasicBlock<'ctx>, std::vec::Vec<PhiValue<'ctx>>)>,
}
impl<'a, 'ctx> Scope<'a, 'ctx> {
pub(crate) fn get(&self, symbol: &Symbol) -> Option<&(InLayout<'a>, BasicValueEnum<'ctx>)> {
self.symbols.get(symbol)
}
pub(crate) fn insert(&mut self, symbol: Symbol, value: (InLayout<'a>, BasicValueEnum<'ctx>)) {
self.symbols.insert(symbol, value);
}
pub(crate) fn insert_top_level_thunk(
&mut self,
symbol: Symbol,
layout: ProcLayout<'a>,
function_value: FunctionValue<'ctx>,
) {
self.top_level_thunks
.insert(symbol, (layout, function_value));
}
fn remove(&mut self, symbol: &Symbol) {
self.symbols.remove(symbol);
}
pub fn retain_top_level_thunks_for_module(&mut self, module_id: ModuleId) {
self.top_level_thunks
.retain(|s, _| s.module_id() == module_id);
}
}
#[derive(Debug, Clone, Copy)]
pub enum LlvmBackendMode {
/// Assumes primitives (roc_alloc, roc_panic, etc) are provided by the host
@ -919,7 +886,7 @@ fn small_str_ptr_width_4<'ctx>(env: &Env<'_, 'ctx, '_>, str_literal: &str) -> St
)
}
pub fn build_exp_call<'a, 'ctx>(
pub(crate) fn build_exp_call<'a, 'ctx>(
env: &Env<'a, 'ctx, '_>,
layout_interner: &mut STLayoutInterner<'a>,
layout_ids: &mut LayoutIds<'a>,
@ -1043,7 +1010,7 @@ fn struct_pointer_from_fields<'a, 'ctx, 'env, I>(
}
}
pub fn build_exp_expr<'a, 'ctx>(
pub(crate) fn build_exp_expr<'a, 'ctx>(
env: &Env<'a, 'ctx, '_>,
layout_interner: &mut STLayoutInterner<'a>,
layout_ids: &mut LayoutIds<'a>,
@ -2419,7 +2386,7 @@ pub fn store_roc_value<'a, 'ctx>(
}
}
pub fn build_exp_stmt<'a, 'ctx>(
pub(crate) fn build_exp_stmt<'a, 'ctx>(
env: &Env<'a, 'ctx, '_>,
layout_interner: &mut STLayoutInterner<'a>,
layout_ids: &mut LayoutIds<'a>,
@ -2636,7 +2603,7 @@ pub fn build_exp_stmt<'a, 'ctx>(
}
// store this join point
scope.join_points.insert(*id, (cont_block, joinpoint_args));
scope.insert_join_point(*id, cont_block, joinpoint_args);
// construct the blocks that may jump to this join point
build_exp_stmt(
@ -2655,11 +2622,9 @@ pub fn build_exp_stmt<'a, 'ctx>(
builder.position_at_end(cont_block);
// bind the values
let ref_join_points = &scope.join_points.get(id).unwrap().1;
for (phi_value, param) in ref_join_points.iter().zip(parameters.iter()) {
let value = phi_value.as_basic_value();
scope.symbols.insert(param.symbol, (param.layout, value));
}
scope
.bind_parameters_to_join_point(*id, parameters.iter())
.expect("join point not found, but it was inserted above");
// put the continuation in
let result = build_exp_stmt(
@ -2673,7 +2638,7 @@ pub fn build_exp_stmt<'a, 'ctx>(
);
// remove this join point again
scope.join_points.remove(id);
scope.remove_join_point(*id);
cont_block.move_after(phi_block).unwrap();
@ -2683,7 +2648,7 @@ pub fn build_exp_stmt<'a, 'ctx>(
Jump(join_point, arguments) => {
let builder = env.builder;
let context = env.context;
let (cont_block, argument_phi_values) = scope.join_points.get(join_point).unwrap();
let (cont_block, argument_phi_values) = scope.get_join_point(*join_point).unwrap();
let current_block = builder.get_insert_block().unwrap();
@ -2995,7 +2960,7 @@ pub fn build_exp_stmt<'a, 'ctx>(
}
}
pub fn load_symbol<'ctx>(scope: &Scope<'_, 'ctx>, symbol: &Symbol) -> BasicValueEnum<'ctx> {
pub(crate) fn load_symbol<'ctx>(scope: &Scope<'_, 'ctx>, symbol: &Symbol) -> BasicValueEnum<'ctx> {
match scope.get(symbol) {
Some((_, ptr)) => *ptr,
@ -4556,7 +4521,7 @@ fn make_exception_catching_wrapper<'a, 'ctx>(
wrapper_function
}
pub fn build_proc_headers<'a, 'r, 'ctx>(
pub(crate) fn build_proc_headers<'a, 'r, 'ctx>(
env: &'r Env<'a, 'ctx, '_>,
layout_interner: &'r mut STLayoutInterner<'a>,
mod_solutions: &'a ModSolutions,

View file

@ -1,7 +1,5 @@
use crate::llvm::bitcode::build_dec_wrapper;
use crate::llvm::build::{
allocate_with_refcount_help, cast_basic_basic, Env, RocFunctionCall, Scope,
};
use crate::llvm::build::{allocate_with_refcount_help, cast_basic_basic, Env, RocFunctionCall};
use crate::llvm::convert::basic_type_from_layout;
use inkwell::builder::Builder;
use inkwell::types::{BasicType, PointerType};
@ -17,6 +15,7 @@ use super::build::{
create_entry_block_alloca, load_roc_value, load_symbol, store_roc_value, BuilderExt,
};
use super::convert::zig_list_type;
use super::scope::Scope;
use super::struct_::struct_from_fields;
fn call_list_bitcode_fn_1<'ctx>(

View file

@ -20,9 +20,9 @@ use roc_region::all::Region;
use super::build::BuilderExt;
use super::build::{
add_func, load_roc_value, load_symbol_and_layout, use_roc_value, FunctionSpec, LlvmBackendMode,
Scope,
};
use super::convert::struct_type_from_union_layout;
use super::scope::Scope;
pub(crate) struct SharedMemoryPointer<'ctx>(PointerValue<'ctx>);

View file

@ -53,9 +53,9 @@ use crate::llvm::{
refcounting::PointerToRefcount,
};
use super::{build::throw_internal_exception, convert::zig_with_overflow_roc_dec};
use super::{build::throw_internal_exception, convert::zig_with_overflow_roc_dec, scope::Scope};
use super::{
build::{load_symbol, load_symbol_and_layout, Env, Scope},
build::{load_symbol, load_symbol_and_layout, Env},
convert::zig_dec_type,
};

View file

@ -10,4 +10,5 @@ mod intrinsics;
mod lowlevel;
pub mod refcounting;
mod scope;
mod struct_;

View file

@ -0,0 +1,88 @@
use inkwell::{
basic_block::BasicBlock,
values::{BasicValueEnum, FunctionValue, PhiValue},
};
use roc_collections::ImMap;
use roc_module::symbol::{ModuleId, Symbol};
use roc_mono::{
ir::{JoinPointId, Param, ProcLayout},
layout::InLayout,
};
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub(crate) struct Scope<'a, 'ctx> {
symbols: ImMap<Symbol, (InLayout<'a>, BasicValueEnum<'ctx>)>,
top_level_thunks: ImMap<Symbol, (ProcLayout<'a>, FunctionValue<'ctx>)>,
join_points: ImMap<JoinPointId, (BasicBlock<'ctx>, Vec<PhiValue<'ctx>>)>,
}
#[derive(Debug)]
pub(crate) struct JoinPointNotFound;
impl<'a, 'ctx> Scope<'a, 'ctx> {
pub fn get(&self, symbol: &Symbol) -> Option<&(InLayout<'a>, BasicValueEnum<'ctx>)> {
self.symbols.get(symbol)
}
pub fn insert(&mut self, symbol: Symbol, value: (InLayout<'a>, BasicValueEnum<'ctx>)) {
self.symbols.insert(symbol, value);
}
pub fn insert_top_level_thunk(
&mut self,
symbol: Symbol,
layout: ProcLayout<'a>,
function_value: FunctionValue<'ctx>,
) {
self.top_level_thunks
.insert(symbol, (layout, function_value));
}
pub fn remove(&mut self, symbol: &Symbol) {
self.symbols.remove(symbol);
}
pub fn retain_top_level_thunks_for_module(&mut self, module_id: ModuleId) {
self.top_level_thunks
.retain(|s, _| s.module_id() == module_id);
}
pub fn insert_join_point(
&mut self,
join_point_id: JoinPointId,
bb: BasicBlock<'ctx>,
phis: Vec<PhiValue<'ctx>>,
) {
self.join_points.insert(join_point_id, (bb, phis));
}
pub fn remove_join_point(&mut self, join_point_id: JoinPointId) {
self.join_points.remove(&join_point_id);
}
pub fn get_join_point(
&self,
join_point_id: JoinPointId,
) -> Option<&(BasicBlock<'ctx>, Vec<PhiValue<'ctx>>)> {
self.join_points.get(&join_point_id)
}
pub fn bind_parameters_to_join_point(
&mut self,
join_point_id: JoinPointId,
parameters: impl IntoIterator<Item = &'a Param<'a>>,
) -> Result<(), JoinPointNotFound> {
let ref_join_points = &self
.join_points
.get(&join_point_id)
.ok_or(JoinPointNotFound)?
.1;
for (phi_value, param) in ref_join_points.iter().zip(parameters.into_iter()) {
let value = phi_value.as_basic_value();
self.symbols.insert(param.symbol, (param.layout, value));
}
Ok(())
}
}

View file

@ -11,11 +11,12 @@ use roc_mono::layout::{InLayout, LayoutInterner, LayoutRepr, STLayoutInterner};
use crate::llvm::build::use_roc_value;
use super::{
build::{load_symbol_and_layout, BuilderExt, Env, Scope},
build::{load_symbol_and_layout, BuilderExt, Env},
convert::basic_type_from_layout,
scope::Scope,
};
pub enum RocStructType<'ctx> {
pub(crate) enum RocStructType<'ctx> {
/// The roc struct should be passed by rvalue.
ByValue(StructType<'ctx>),
}
@ -60,7 +61,7 @@ fn basic_type_from_record<'a, 'ctx>(
.struct_type(field_types.into_bump_slice(), false)
}
pub enum RocStruct<'ctx> {
pub(crate) enum RocStruct<'ctx> {
/// The roc struct should be passed by rvalue.
ByValue(StructValue<'ctx>),
}