Monomorphize record literals

This commit is contained in:
Richard Feldman 2024-11-16 13:24:32 -05:00
parent 074161733e
commit c883a6b5ac
No known key found for this signature in database
GPG key ID: DAC334802F365236
6 changed files with 171 additions and 341 deletions

View file

@ -1,5 +1,5 @@
use crate::{
mono_ir::{MonoExpr, MonoExprId, MonoExprs},
mono_ir::{sort_fields, MonoExpr, MonoExprId, MonoExprs},
mono_module::Interns,
mono_num::Number,
mono_type::{MonoType, MonoTypes, Primitive},
@ -9,8 +9,10 @@ use crate::{
use bumpalo::Bump;
use roc_can::expr::{Expr, IntValue};
use roc_collections::Push;
use roc_region::all::Region;
use roc_solve::module::Solved;
use roc_types::subs::Subs;
use soa::{Index, Slice};
pub struct Env<'a, 'c, 'd, 'i, 's, 't, P> {
arena: &'a Bump,
@ -52,7 +54,12 @@ impl<'a, 'c, 'd, 'i, 's, 't, P: Push<Problem>> Env<'a, 'c, 'd, 'i, 's, 't, P> {
}
}
pub fn to_mono_expr(&mut self, can_expr: Expr) -> Option<MonoExprId> {
pub fn to_mono_expr(
&mut self,
can_expr: Expr,
region: Region,
get_expr_id: impl FnOnce() -> Option<MonoExprId>,
) -> Option<MonoExprId> {
let problems = &mut self.problems;
let mono_types = &mut self.mono_types;
let mut mono_from_var = |var| {
@ -67,11 +74,13 @@ impl<'a, 'c, 'd, 'i, 's, 't, P: Push<Problem>> Env<'a, 'c, 'd, 'i, 's, 't, P> {
)
};
let mut add = |expr| self.mono_exprs.add(expr);
macro_rules! compiler_bug {
($problem:expr) => {{
($problem:expr, $region:expr) => {{
problems.push($problem);
Some(add(MonoExpr::CompilerBug($problem)))
Some(
self.mono_exprs
.add(MonoExpr::CompilerBug($problem), $region),
)
}};
}
@ -80,11 +89,14 @@ impl<'a, 'c, 'd, 'i, 's, 't, P: Push<Problem>> Env<'a, 'c, 'd, 'i, 's, 't, P> {
Some(mono_id) => match mono_types.get(mono_id) {
MonoType::Primitive(primitive) => to_frac(*primitive, val, problems),
other => {
return compiler_bug!(Problem::NumSpecializedToWrongType(Some(*other)));
return compiler_bug!(
Problem::NumSpecializedToWrongType(Some(*other)),
region
);
}
},
None => {
return compiler_bug!(Problem::NumSpecializedToWrongType(None));
return compiler_bug!(Problem::NumSpecializedToWrongType(None), region);
}
},
Expr::Num(var, _str, int_value, _) | Expr::Int(var, _, _str, int_value, _) => {
@ -93,11 +105,14 @@ impl<'a, 'c, 'd, 'i, 's, 't, P: Push<Problem>> Env<'a, 'c, 'd, 'i, 's, 't, P> {
Some(mono_id) => match mono_types.get(mono_id) {
MonoType::Primitive(primitive) => to_num(*primitive, int_value, problems),
other => {
return compiler_bug!(Problem::NumSpecializedToWrongType(Some(*other)));
return compiler_bug!(
Problem::NumSpecializedToWrongType(Some(*other)),
region
);
}
},
None => {
return compiler_bug!(Problem::NumSpecializedToWrongType(None));
return compiler_bug!(Problem::NumSpecializedToWrongType(None), region);
}
}
}
@ -109,11 +124,14 @@ impl<'a, 'c, 'd, 'i, 's, 't, P: Push<Problem>> Env<'a, 'c, 'd, 'i, 's, 't, P> {
Some(mono_id) => match mono_types.get(mono_id) {
MonoType::Primitive(primitive) => char_to_int(*primitive, char, problems),
other => {
return compiler_bug!(Problem::CharSpecializedToWrongType(Some(*other)));
return compiler_bug!(
Problem::CharSpecializedToWrongType(Some(*other)),
region
);
}
},
None => {
return compiler_bug!(Problem::CharSpecializedToWrongType(None));
return compiler_bug!(Problem::CharSpecializedToWrongType(None), region);
}
},
Expr::Str(contents) => MonoExpr::Str(
@ -124,6 +142,42 @@ impl<'a, 'c, 'd, 'i, 's, 't, P: Push<Problem>> Env<'a, 'c, 'd, 'i, 's, 't, P> {
// Empty records are zero-sized and should be discarded.
return None;
}
Expr::Record { record_var, fields } => {
// Sort the fields alphabetically by name.
let mut fields = sort_fields(fields, self.arena);
// Reserve a slice of IDs up front. This is so that we have a contiguous array
// of field IDs at the end of this, each corresponding to the appropriate record field.
let field_ids: Slice<MonoExprId> = self.mono_exprs.reserve_ids(fields.len() as u16);
let mut next_field_id = field_ids.start();
// Generate a MonoExpr for each field, using the reserved IDs so that we end up with
// that Slice being populated with the exprs in the fields, with the correct ordering.
fields.retain(|(_name, field)| {
let loc_expr = field.loc_expr;
self.to_mono_expr(loc_expr.value, loc_expr.region, || unsafe {
// Safety: This will run *at most* field.len() times, possibly less,
// so this will never create an index that's out of bounds.
let answer = MonoExprId::new_unchecked(Index::new(next_field_id));
next_field_id += 1;
Some(answer)
})
// Discard all the zero-sized fields as we go. We don't need to keep the contents
// of the Option because we already know it's the ID we passed in.
.is_some()
});
// If we dropped any fields because they were being zero-sized,
// drop the same number of reserved IDs so that they still line up.
field_ids.truncate(fields.len() as u16);
// If all fields ended up being zero-sized, this would compile to an empty record; return None.
let field_ids = field_ids.into_nonempty_slice()?;
let todo = (); // TODO: store debuginfo for the record type, including ideally type alias and/or opaque type names.
MonoExpr::Struct(field_ids)
}
_ => todo!(),
// Expr::List {
// elem_var,
@ -163,10 +217,6 @@ impl<'a, 'c, 'd, 'i, 's, 't, P: Push<Problem>> Env<'a, 'c, 'd, 'i, 's, 't, P> {
// ret_var,
// } => todo!(),
// Expr::Closure(closure_data) => todo!(),
// Expr::Record { record_var, fields } => {
// // TODO *after* having converted to mono (dropping zero-sized fields), if no fields remain, then return None.
// todo!()
// }
// Expr::Tuple { tuple_var, elems } => todo!(),
// Expr::ImportParams(module_id, region, _) => todo!(),
// Expr::Crash { msg, ret_var } => todo!(),
@ -236,7 +286,11 @@ impl<'a, 'c, 'd, 'i, 's, 't, P: Push<Problem>> Env<'a, 'c, 'd, 'i, 's, 't, P> {
}
};
Some(add(mono_expr))
let mono_expr_id = get_expr_id()?;
self.mono_exprs.insert(mono_expr_id, mono_expr, region);
Some(mono_expr_id)
}
}

View file

@ -2,11 +2,13 @@ use crate::{
foreign_symbol::ForeignSymbolId, mono_module::InternedStrId, mono_num::Number,
mono_struct::MonoFieldId, mono_type::MonoTypeId, specialize_type::Problem,
};
use roc_can::expr::Recursive;
use roc_collections::soa::Slice;
use roc_module::low_level::LowLevel;
use bumpalo::Bump;
use roc_can::expr::{Field, Recursive};
use roc_module::symbol::Symbol;
use soa::{Id, NonEmptySlice, Slice2, Slice3};
use roc_module::{ident::Lowercase, low_level::LowLevel};
use roc_region::all::Region;
use soa::{Id, NonEmptySlice, Slice, Slice2, Slice3};
use std::iter;
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct MonoPatternId {
@ -26,24 +28,30 @@ pub struct Def {
#[derive(Debug)]
pub struct MonoExprs {
// TODO convert to Vec2
exprs: Vec<MonoExpr>,
regions: Vec<Region>,
}
impl MonoExprs {
pub fn new() -> Self {
Self { exprs: Vec::new() }
Self {
exprs: Vec::new(),
regions: Vec::new(),
}
}
pub fn add(&mut self, expr: MonoExpr) -> MonoExprId {
pub fn add(&mut self, expr: MonoExpr, region: Region) -> MonoExprId {
let index = self.exprs.len() as u32;
self.exprs.push(expr);
self.regions.push(region);
MonoExprId {
inner: Id::new(index),
}
}
pub fn get(&self, id: MonoExprId) -> &MonoExpr {
pub fn get_expr(&self, id: MonoExprId) -> &MonoExpr {
debug_assert!(
self.exprs.get(id.inner.index()).is_some(),
"A MonoExprId was not found in MonoExprs. This should never happen!"
@ -52,6 +60,48 @@ impl MonoExprs {
// Safety: we should only ever hand out MonoExprIds that are valid indices into here.
unsafe { self.exprs.get_unchecked(id.inner.index() as usize) }
}
pub fn get_region(&self, id: MonoExprId) -> Region {
debug_assert!(
self.regions.get(id.inner.index()).is_some(),
"A MonoExprId was not found in MonoExprs. This should never happen!"
);
// Safety: we should only ever hand out MonoExprIds that are valid indices into here.
unsafe { *self.regions.get_unchecked(id.inner.index() as usize) }
}
pub fn reserve_ids(&self, len: u16) -> Slice<MonoExprId> {
let answer = Slice::new(self.exprs.len() as u32, len);
// These should all be overwritten; if they aren't, that's a problem!
self.exprs.extend(iter::repeat(MonoExpr::CompilerBug(
Problem::UninitializedReservedExpr,
)));
self.regions.extend(iter::repeat(Region::zero()));
answer
}
pub(crate) fn insert(&self, id: MonoExprId, mono_expr: MonoExpr, region: Region) {
debug_assert!(
self.exprs.get(id.inner.index()).is_some(),
"A MonoExprId was not found in MonoExprs. This should never happen!"
);
debug_assert!(
self.regions.get(id.inner.index()).is_some(),
"A MonoExprId was not found in MonoExprs. This should never happen!"
);
let index = id.inner.index() as usize;
// Safety: we should only ever hand out MonoExprIds that are valid indices into here.
unsafe {
*self.exprs.get_unchecked_mut(index) = mono_expr;
*self.regions.get_unchecked_mut(index) = region;
}
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
@ -59,6 +109,12 @@ pub struct MonoExprId {
inner: Id<MonoExpr>,
}
impl MonoExprId {
pub(crate) unsafe fn new_unchecked(inner: Id<MonoExpr>) -> Self {
Self { inner }
}
}
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum MonoExpr {
Str(InternedStrId),
@ -136,12 +192,9 @@ pub enum MonoExpr {
recursive: Recursive,
},
/// Either a record literal or a tuple literal.
/// Rather than storing field names, we instead store a u16 field index.
Struct {
struct_type: MonoTypeId,
fields: Slice2<MonoFieldId, MonoTypeId>,
},
/// A record literal or a tuple literal.
/// These have already been sorted alphabetically.
Struct(NonEmptySlice<MonoExprId>),
/// The "crash" keyword. Importantly, during code gen we must mark this as "nothing happens after this"
Crash {
@ -247,3 +300,13 @@ pub enum DestructType {
Optional(MonoTypeId, MonoExprId),
Guard(MonoTypeId, MonoPatternId),
}
/// Sort the given fields alphabetically by name.
pub fn sort_fields<'a>(
fields: impl IntoIterator<Item = (Lowercase, Field)>,
arena: &'a Bump,
) -> bumpalo::collections::Vec<'a, (Lowercase, Field)> {
let mut fields = bumpalo::collections::Vec::from_iter_in(fields.into_iter(), arena);
fields.sort_by_key(|(name, _field)| name);
fields
}

View file

@ -1,310 +0,0 @@
use crate::expr::{self, Declarations, Expr, FunctionDef};
use crate::specialize_type::{MonoCache, Problem};
use roc_collections::VecMap;
use roc_module::symbol::Symbol;
use roc_types::subs::{Subs, Variable};
struct Context {
symbols: Symbol,
fresh_tvar: Box<dyn Fn() -> Variable>,
specializations: Specializations,
}
struct Specializations {
symbols: Symbol,
fenv: Vec<(Symbol, expr::FunctionDef)>,
specializations: Vec<(SpecializationKey, NeededSpecialization)>,
}
#[derive(PartialEq, Eq)]
struct SpecializationKey(Symbol, Variable);
struct NeededSpecialization {
def: expr::FunctionDef,
name_new: Symbol,
t_new: Variable,
specialized: Option<expr::Def>,
}
impl Specializations {
fn make(symbols: Symbol, program: &expr::Declarations) -> Self {
let fenv = program
.iter_top_down()
.filter_map(|(_, tag)| match tag {
expr::DeclarationTag::Function(idx)
| expr::DeclarationTag::Recursive(idx)
| expr::DeclarationTag::TailRecursive(idx) => {
let func = &program.function_bodies[idx.index()];
Some((func.value.name, func.value.clone()))
}
_ => None,
})
.collect();
Specializations {
symbols,
fenv,
specializations: Vec::new(),
}
}
fn specialize_fn(
&mut self,
mono_cache: &mut MonoCache,
name: Symbol,
t_new: Variable,
) -> Option<Symbol> {
let specialization = (name, t_new);
if let Some((_, needed)) = self
.specializations
.iter()
.find(|(key, _)| *key == specialization)
{
Some(needed.name_new)
} else {
let def = self.fenv.iter().find(|(n, _)| *n == name)?.1.clone();
let name_new = self.symbols.fresh_symbol_named(name);
let needed_specialization = NeededSpecialization {
def,
name_new,
t_new,
specialized: None,
};
self.specializations
.push((specialization, needed_specialization));
Some(name_new)
}
}
fn next_needed_specialization(&mut self) -> Option<&mut NeededSpecialization> {
self.specializations.iter_mut().find_map(|(_, ns)| {
if ns.specialized.is_none() {
Some(ns)
} else {
None
}
})
}
fn solved_specializations(&self) -> Vec<expr::Def> {
self.specializations
.iter()
.filter_map(|(_, ns)| ns.specialized.clone())
.collect()
}
}
fn specialize_expr(
ctx: &mut Context,
ty_cache: &mut Subs,
mono_cache: &mut MonoCache,
expr: &Expr,
) -> Expr {
match expr {
Expr::Var(x) => {
if let Some(y) = ctx
.specializations
.specialize_fn(mono_cache, *x, expr.get_type())
{
Expr::Var(y)
} else {
expr.clone()
}
}
Expr::Int(i) => Expr::Int(*i),
Expr::Str(s) => Expr::Str(s.clone()),
Expr::Tag(t, args) => {
let new_args = args
.iter()
.map(|a| specialize_expr(ctx, ty_cache, mono_cache, a))
.collect();
Expr::Tag(*t, new_args)
}
Expr::Record(fields) => {
let new_fields = fields
.iter()
.map(|(f, e)| (*f, specialize_expr(ctx, ty_cache, mono_cache, e)))
.collect();
Expr::Record(new_fields)
}
Expr::Access(e, f) => {
Expr::Access(Box::new(specialize_expr(ctx, ty_cache, mono_cache, e)), *f)
}
Expr::Let(def, rest) => {
let new_def = match def {
expr::Def::Letfn(f) => expr::Def::Letfn(FunctionDef {
recursive: f.recursive,
bind: (
mono_cache.monomorphize_var(ty_cache, &mut Vec::new(), f.bind.0),
f.bind.1,
),
arg: (
mono_cache.monomorphize_var(ty_cache, &mut Vec::new(), f.arg.0),
f.arg.1,
),
body: Box::new(specialize_expr(ctx, ty_cache, mono_cache, &f.body)),
}),
expr::Def::Letval(v) => expr::Def::Letval(expr::Letval {
bind: (
mono_cache.monomorphize_var(ty_cache, &mut Vec::new(), v.bind.0),
v.bind.1,
),
body: Box::new(specialize_expr(ctx, ty_cache, mono_cache, &v.body)),
}),
};
let new_rest = Box::new(specialize_expr(ctx, ty_cache, mono_cache, rest));
Expr::Let(Box::new(new_def), new_rest)
}
Expr::Clos { arg, body } => {
let new_arg = (
mono_cache.monomorphize_var(ty_cache, &mut Vec::new(), arg.0),
arg.1,
);
let new_body = Box::new(specialize_expr(ctx, ty_cache, mono_cache, body));
Expr::Clos {
arg: new_arg,
body: new_body,
}
}
Expr::Call(f, a) => {
let new_f = Box::new(specialize_expr(ctx, ty_cache, mono_cache, f));
let new_a = Box::new(specialize_expr(ctx, ty_cache, mono_cache, a));
Expr::Call(new_f, new_a)
}
Expr::KCall(kfn, args) => {
let new_args = args
.iter()
.map(|a| specialize_expr(ctx, ty_cache, mono_cache, a))
.collect();
Expr::KCall(*kfn, new_args)
}
Expr::When(e, branches) => {
let new_e = Box::new(specialize_expr(ctx, ty_cache, mono_cache, e));
let new_branches = branches
.iter()
.map(|(p, e)| {
let new_p = specialize_pattern(ctx, ty_cache, mono_cache, p);
let new_e = specialize_expr(ctx, ty_cache, mono_cache, e);
(new_p, new_e)
})
.collect();
Expr::When(new_e, new_branches)
}
}
}
fn specialize_pattern(
ctx: &mut Context,
ty_cache: &mut Subs,
mono_cache: &mut MonoCache,
pattern: &expr::Pattern,
) -> expr::Pattern {
match pattern {
expr::Pattern::PVar(x) => expr::Pattern::PVar(*x),
expr::Pattern::PTag(tag, args) => {
let new_args = args
.iter()
.map(|a| specialize_pattern(ctx, ty_cache, mono_cache, a))
.collect();
expr::Pattern::PTag(*tag, new_args)
}
}
}
fn specialize_let_fn(
ctx: &mut Context,
ty_cache: &mut Subs,
mono_cache: &mut MonoCache,
t_new: Variable,
name_new: Symbol,
f: &expr::FunctionDef,
) -> expr::Def {
mono_cache.monomorphize_var(ty_cache, &mut Vec::new(), t_new);
let t = mono_cache.monomorphize_var(ty_cache, &mut Vec::new(), f.bind.0);
let t_arg = mono_cache.monomorphize_var(ty_cache, &mut Vec::new(), f.arg.0);
let body = specialize_expr(ctx, ty_cache, mono_cache, &f.body);
expr::Def::Letfn(FunctionDef {
recursive: f.recursive,
bind: (t, name_new),
arg: (t_arg, f.arg.1),
body: Box::new(body),
})
}
fn specialize_let_val(ctx: &mut Context, v: &expr::Letval) -> expr::Def {
let mut ty_cache = Subs::new();
let mut mono_cache = MonoCache::from_subs(&ty_cache);
let t = mono_cache.monomorphize_var(&mut ty_cache, &mut Vec::new(), v.bind.0);
let body = specialize_expr(ctx, &mut ty_cache, &mut mono_cache, &v.body);
expr::Def::Letval(expr::Letval {
bind: (t, v.bind.1),
body: Box::new(body),
})
}
fn specialize_run_def(ctx: &mut Context, run: &expr::Run) -> expr::Run {
let mut ty_cache = Subs::new();
let mut mono_cache = MonoCache::from_subs(&ty_cache);
let t = mono_cache.monomorphize_var(&mut ty_cache, &mut Vec::new(), run.bind.0);
let body = specialize_expr(ctx, &mut ty_cache, &mut mono_cache, &run.body);
expr::Run {
bind: (t, run.bind.1),
body: Box::new(body),
ty: run.ty,
}
}
fn make_context(
symbols: Symbol,
fresh_tvar: Box<dyn Fn() -> Variable>,
program: &expr::Declarations,
) -> Context {
Context {
symbols,
fresh_tvar,
specializations: Specializations::make(symbols, program),
}
}
fn loop_specializations(ctx: &mut Context) {
while let Some(needed) = ctx.specializations.next_needed_specialization() {
let mut ty_cache = Subs::new();
let mut mono_cache = MonoCache::from_subs(&ty_cache);
let def = specialize_let_fn(
ctx,
&mut ty_cache,
&mut mono_cache,
needed.t_new,
needed.name_new,
&needed.def,
);
needed.specialized = Some(def);
}
}
pub fn lower(ctx: &mut Context, program: &expr::Declarations) -> expr::Declarations {
let mut new_program = expr::Declarations::new();
for (idx, tag) in program.iter_top_down() {
match tag {
expr::DeclarationTag::Value => {
let def = specialize_let_val(ctx, &program.expressions[idx]);
new_program.push_def(def);
}
expr::DeclarationTag::Run(run_idx) => {
let run = specialize_run_def(ctx, &program.expressions[run_idx.index()]);
new_program.push_run(run);
}
_ => {}
}
}
loop_specializations(ctx);
let other_defs = ctx.specializations.solved_specializations();
new_program.extend(other_defs);
new_program
}

View file

@ -35,6 +35,7 @@ pub enum Problem {
Option<MonoType>, // `None` means it specialized to Unit
),
BadNumTypeParam,
UninitializedReservedExpr,
}
/// For MonoTypes that are records, store their field indices.

View file

@ -95,7 +95,7 @@ pub fn expect_no_expr(input: impl AsRef<str>) {
let arena = Bump::new();
let mut interns = Interns::new();
let out = specialize_expr(&arena, input.as_ref(), &mut interns);
let actual = out.mono_expr_id.map(|id| out.mono_exprs.get(id));
let actual = out.mono_expr_id.map(|id| out.mono_exprs.get_expr(id));
assert_eq!(None, actual, "This input expr should have specialized to being dicarded as zero-sized, but it didn't: {:?}", input.as_ref());
}
@ -118,7 +118,7 @@ pub fn expect_mono_expr_with_interns<T>(
.mono_expr_id
.expect("This input expr should not have been discarded as zero-sized, but it was discarded: {input:?}");
let actual_expr = out.mono_exprs.get(mono_expr_id); // Must run first, to populate string interns!
let actual_expr = out.mono_exprs.get_expr(mono_expr_id); // Must run first, to populate string interns!
let expected_expr = to_mono_expr(from_interns(&arena, &string_interns));