mirror of
https://github.com/roc-lang/roc.git
synced 2025-10-03 00:24:34 +00:00
Specialize polymorphic non-function expressions
This commit fixes a long-standing bug wherein bindings to polymorphic, non-function expressions would be lowered at the binding site, rather than being specialized at the call site. Concretely, consider the program ``` main = n = 1 idU8 : U8 -> U8 idU8 = \m -> m idU8 n ``` Prior to this commit, we would lower `n = 1` as part of the IR, and the `n` at the call site `idU8 n` would reference the lowered definition. However, at the definition site, `1` has the polymorphic type `Num *` - it is not until the call site that we are able to refine the type bound by `n`, but at that point it's too late. Since the default layout for `Num *` is a signed 64-bit int, we would generate IR like ``` procedure main(): let App.n : Builtin(Int(I64)) = 1i64; ... let App.5 : Builtin(Int(U8)) = CallByName Add.idU8 App.n; ret App.5; ``` But we know `idU8` expects a `u8`; giving it an `i64` is nonsense. Indeed this would trigger LLVM miscompilations later on. To remedy this, we now keep a sidecar table that maps symbols to the polymorphic expression they reference, when they do so. We then specialize references to symbols on the fly at usage sites, similar to how we specialize function usages. Looking at our example, the definition `n = 1` is now never lowered to the IR directly. We only generate code for `1` at each place `n` is referenced. As a larger example, you can imagine that ``` main = n = 1 asU8 : U8 -> U8 asU32 : U32 -> U8 asU8 n + asU32 n ``` is lowered to the moral equivalent of ``` main = asU8 : U8 -> U8 asU32 : U32 -> U8 asU8 1 + asU32 1 ``` Moreover, transient usages of polymorphic expressions are lowered successfully with this approach. 
See for example the `monomorphized_tag_with_polymorphic_arg_and_monomorphic_arg` test in this commit, which checks that ``` main = mono : U8 mono = 15 poly = A wrap = Wrapped poly mono useWrap1 : [Wrapped [A] U8, Other] -> U8 useWrap1 = \w -> when w is Wrapped A n -> n Other -> 0 useWrap2 : [Wrapped [A, B] U8] -> U8 useWrap2 = \w -> when w is Wrapped A n -> n Wrapped B _ -> 0 useWrap1 wrap * useWrap2 wrap ``` has proper code generated for it, in the presence of the polymorphic `wrap` which references the polymorphic `poly`. https://github.com/rtfeldman/roc/pull/2347 had a different approach to this - polymorphic expressions would be converted to (possibly capturing) thunks. This has the benefit of reducing code size if there are many polymorphic usages, but may make the generated code slower and makes integration with the existing IR implementation harder. In practice I think the average number of polymorphic usages of an expression will be very small. Closes https://github.com/rtfeldman/roc/issues/2336 Closes https://github.com/rtfeldman/roc/issues/2254 Closes https://github.com/rtfeldman/roc/issues/2344
This commit is contained in:
parent
3342090e7b
commit
a5de224626
23 changed files with 655 additions and 68 deletions
|
@ -207,6 +207,67 @@ impl<'a> PartialProc<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
enum PartialExprLink {
|
||||
Aliases(Symbol),
|
||||
Expr(roc_can::expr::Expr, Variable),
|
||||
}
|
||||
|
||||
/// A table of symbols to polymorphic expressions. For example, in the program
|
||||
///
|
||||
/// n = 1
|
||||
///
|
||||
/// asU8 : U8 -> U8
|
||||
/// asU8 = \_ -> 1
|
||||
///
|
||||
/// asU32 : U32 -> U8
|
||||
/// asU32 = \_ -> 1
|
||||
///
|
||||
/// asU8 n + asU32 n
|
||||
///
|
||||
/// The expression bound by `n` doesn't have a definite layout until it is used
|
||||
/// at the call sites `asU8 n`, `asU32 n`.
|
||||
///
|
||||
/// Polymorphic *functions* are stored in `PartialProc`s, since functions are
|
||||
/// non longer first-class once we finish lowering to the IR.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PartialExprs(BumpMap<Symbol, PartialExprLink>);
|
||||
|
||||
impl PartialExprs {
|
||||
pub fn new_in<'a>(arena: &'a Bump) -> Self {
|
||||
Self(BumpMap::new_in(arena))
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, symbol: Symbol, expr: roc_can::expr::Expr, expr_var: Variable) {
|
||||
self.0.insert(symbol, PartialExprLink::Expr(expr, expr_var));
|
||||
}
|
||||
|
||||
pub fn insert_alias(&mut self, symbol: Symbol, aliases: Symbol) {
|
||||
self.0.insert(symbol, PartialExprLink::Aliases(aliases));
|
||||
}
|
||||
|
||||
pub fn contains(&self, symbol: Symbol) -> bool {
|
||||
self.0.contains_key(&symbol)
|
||||
}
|
||||
|
||||
pub fn get(&mut self, mut symbol: Symbol) -> Option<(&roc_can::expr::Expr, Variable)> {
|
||||
// In practice the alias chain is very short
|
||||
loop {
|
||||
match self.0.get(&symbol) {
|
||||
None => {
|
||||
return None;
|
||||
}
|
||||
Some(&PartialExprLink::Aliases(real_symbol)) => {
|
||||
symbol = real_symbol;
|
||||
}
|
||||
Some(PartialExprLink::Expr(expr, var)) => {
|
||||
return Some((expr, *var));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
pub enum CapturedSymbols<'a> {
|
||||
None,
|
||||
|
@ -668,6 +729,7 @@ impl<'a> Specialized<'a> {
|
|||
#[derive(Clone, Debug)]
|
||||
pub struct Procs<'a> {
|
||||
pub partial_procs: PartialProcs<'a>,
|
||||
pub partial_exprs: PartialExprs,
|
||||
pub imported_module_thunks: &'a [Symbol],
|
||||
pub module_thunks: &'a [Symbol],
|
||||
pending_specializations: PendingSpecializations<'a>,
|
||||
|
@ -680,6 +742,7 @@ impl<'a> Procs<'a> {
|
|||
pub fn new_in(arena: &'a Bump) -> Self {
|
||||
Self {
|
||||
partial_procs: PartialProcs::new_in(arena),
|
||||
partial_exprs: PartialExprs::new_in(arena),
|
||||
imported_module_thunks: &[],
|
||||
module_thunks: &[],
|
||||
pending_specializations: PendingSpecializations::Finding(Suspended::new_in(arena)),
|
||||
|
@ -3103,16 +3166,20 @@ pub fn with_hole<'a>(
|
|||
_ => {}
|
||||
}
|
||||
|
||||
// continue with the default path
|
||||
let mut stmt = with_hole(
|
||||
env,
|
||||
cont.value,
|
||||
variable,
|
||||
procs,
|
||||
layout_cache,
|
||||
assigned,
|
||||
hole,
|
||||
);
|
||||
let build_rest =
|
||||
|env: &mut Env<'a, '_>,
|
||||
procs: &mut Procs<'a>,
|
||||
layout_cache: &mut LayoutCache<'a>| {
|
||||
with_hole(
|
||||
env,
|
||||
cont.value,
|
||||
variable,
|
||||
procs,
|
||||
layout_cache,
|
||||
assigned,
|
||||
hole,
|
||||
)
|
||||
};
|
||||
|
||||
// a variable is aliased
|
||||
if let roc_can::expr::Expr::Var(original) = def.loc_expr.value {
|
||||
|
@ -3124,18 +3191,17 @@ pub fn with_hole<'a>(
|
|||
//
|
||||
// foo = RBTree.empty
|
||||
|
||||
stmt = handle_variable_aliasing(
|
||||
handle_variable_aliasing(
|
||||
env,
|
||||
procs,
|
||||
layout_cache,
|
||||
def.expr_var,
|
||||
symbol,
|
||||
original,
|
||||
stmt,
|
||||
);
|
||||
|
||||
stmt
|
||||
build_rest,
|
||||
)
|
||||
} else {
|
||||
let rest = build_rest(env, procs, layout_cache);
|
||||
with_hole(
|
||||
env,
|
||||
def.loc_expr.value,
|
||||
|
@ -3143,7 +3209,7 @@ pub fn with_hole<'a>(
|
|||
procs,
|
||||
layout_cache,
|
||||
symbol,
|
||||
env.arena.alloc(stmt),
|
||||
env.arena.alloc(rest),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
|
@ -3328,6 +3394,7 @@ pub fn with_hole<'a>(
|
|||
let mut can_fields = Vec::with_capacity_in(fields.len(), env.arena);
|
||||
|
||||
enum Field {
|
||||
// TODO: rename this since it can handle unspecialized expressions now too
|
||||
Function(Symbol, Variable),
|
||||
ValueSymbol,
|
||||
Field(roc_can::expr::Field),
|
||||
|
@ -3338,7 +3405,7 @@ pub fn with_hole<'a>(
|
|||
use ReuseSymbol::*;
|
||||
match fields.remove(&label) {
|
||||
Some(field) => match can_reuse_symbol(env, procs, &field.loc_expr.value) {
|
||||
Imported(symbol) | LocalFunction(symbol) => {
|
||||
Imported(symbol) | LocalFunction(symbol) | UnspecializedExpr(symbol) => {
|
||||
field_symbols.push(symbol);
|
||||
can_fields.push(Field::Function(symbol, variable));
|
||||
}
|
||||
|
@ -4064,6 +4131,9 @@ pub fn with_hole<'a>(
|
|||
LocalFunction(_) => {
|
||||
unreachable!("if this was known to be a function, we would not be here")
|
||||
}
|
||||
UnspecializedExpr(_) => {
|
||||
unreachable!("if this was known to be an unspecialized expression, we would not be here")
|
||||
}
|
||||
Imported(thunk_name) => {
|
||||
debug_assert!(procs.is_imported_module_thunk(thunk_name));
|
||||
|
||||
|
@ -5023,6 +5093,31 @@ fn register_capturing_closure<'a>(
|
|||
}
|
||||
}
|
||||
|
||||
/// Returns `true` when `expr` is syntactically a literal-like value: a number,
/// string, list, tag, or record expression.
fn is_literal_like(expr: &roc_can::expr::Expr) -> bool {
    use roc_can::expr::Expr;

    matches!(
        expr,
        Expr::Num(..)
            | Expr::Int(..)
            | Expr::Float(..)
            | Expr::List { .. }
            | Expr::Str(..)
            | Expr::ZeroArgumentTag { .. }
            | Expr::Tag { .. }
            | Expr::Record { .. }
    )
}
|
||||
|
||||
fn expr_is_polymorphic<'a>(
|
||||
env: &mut Env<'a, '_>,
|
||||
expr: &roc_can::expr::Expr,
|
||||
expr_var: Variable,
|
||||
) -> bool {
|
||||
// TODO: I don't think we need the `is_literal_like` check, but taking it slow for now...
|
||||
let is_flex_or_rigid = |c: &Content| matches!(c, Content::FlexVar(_) | Content::RigidVar(_));
|
||||
is_literal_like(expr) && env.subs.var_contains_content(expr_var, is_flex_or_rigid)
|
||||
}
|
||||
|
||||
pub fn from_can<'a>(
|
||||
env: &mut Env<'a, '_>,
|
||||
variable: Variable,
|
||||
|
@ -5177,19 +5272,26 @@ pub fn from_can<'a>(
|
|||
// or
|
||||
//
|
||||
// foo = RBTree.empty
|
||||
let mut rest = from_can(env, def.expr_var, cont.value, procs, layout_cache);
|
||||
|
||||
rest = handle_variable_aliasing(
|
||||
// TODO: right now we need to help out rustc with the closure types;
|
||||
// it isn't able to infer the right lifetime bounds. See if we
|
||||
// can remove the annotations in the future.
|
||||
let build_rest =
|
||||
|env: &mut Env<'a, '_>,
|
||||
procs: &mut Procs<'a>,
|
||||
layout_cache: &mut LayoutCache<'a>| {
|
||||
from_can(env, def.expr_var, cont.value, procs, layout_cache)
|
||||
};
|
||||
|
||||
return handle_variable_aliasing(
|
||||
env,
|
||||
procs,
|
||||
layout_cache,
|
||||
def.expr_var,
|
||||
*symbol,
|
||||
original,
|
||||
rest,
|
||||
build_rest,
|
||||
);
|
||||
|
||||
return rest;
|
||||
}
|
||||
roc_can::expr::Expr::LetNonRec(nested_def, nested_cont, nested_annotation) => {
|
||||
use roc_can::expr::Expr::*;
|
||||
|
@ -5273,6 +5375,21 @@ pub fn from_can<'a>(
|
|||
|
||||
return from_can(env, variable, new_outer, procs, layout_cache);
|
||||
}
|
||||
ref body if expr_is_polymorphic(env, body, def.expr_var) => {
|
||||
// This is a pattern like
|
||||
//
|
||||
// n = 1
|
||||
// asU8 n
|
||||
//
|
||||
// At the definition site `n = 1` we only know `1` to have the type `[Int *]`,
|
||||
// which won't be refined until the call `asU8 n`. Add it as a partial expression
|
||||
// that will be specialized at each concrete usage site.
|
||||
procs
|
||||
.partial_exprs
|
||||
.insert(*symbol, def.loc_expr.value, def.expr_var);
|
||||
|
||||
return from_can(env, variable, cont.value, procs, layout_cache);
|
||||
}
|
||||
_ => {
|
||||
let rest = from_can(env, variable, cont.value, procs, layout_cache);
|
||||
return with_hole(
|
||||
|
@ -6290,6 +6407,7 @@ enum ReuseSymbol {
|
|||
Imported(Symbol),
|
||||
LocalFunction(Symbol),
|
||||
Value(Symbol),
|
||||
UnspecializedExpr(Symbol),
|
||||
NotASymbol,
|
||||
}
|
||||
|
||||
|
@ -6307,6 +6425,8 @@ fn can_reuse_symbol<'a>(
|
|||
Imported(symbol)
|
||||
} else if procs.partial_procs.contains_key(symbol) {
|
||||
LocalFunction(symbol)
|
||||
} else if procs.partial_exprs.contains(symbol) {
|
||||
UnspecializedExpr(symbol)
|
||||
} else {
|
||||
Value(symbol)
|
||||
}
|
||||
|
@ -6326,15 +6446,29 @@ fn possible_reuse_symbol<'a>(
|
|||
}
|
||||
}
|
||||
|
||||
fn handle_variable_aliasing<'a>(
|
||||
fn handle_variable_aliasing<'a, BuildRest>(
|
||||
env: &mut Env<'a, '_>,
|
||||
procs: &mut Procs<'a>,
|
||||
layout_cache: &mut LayoutCache<'a>,
|
||||
variable: Variable,
|
||||
left: Symbol,
|
||||
right: Symbol,
|
||||
mut result: Stmt<'a>,
|
||||
) -> Stmt<'a> {
|
||||
build_rest: BuildRest,
|
||||
) -> Stmt<'a>
|
||||
where
|
||||
BuildRest: FnOnce(&mut Env<'a, '_>, &mut Procs<'a>, &mut LayoutCache<'a>) -> Stmt<'a>,
|
||||
{
|
||||
if procs.partial_exprs.contains(right) {
|
||||
// If `right` links to a partial expression, make sure we link `left` to it as well, so
|
||||
// that usages of it will be specialized when building the rest of the program.
|
||||
procs.partial_exprs.insert_alias(left, right);
|
||||
return build_rest(env, procs, layout_cache);
|
||||
}
|
||||
|
||||
// Otherwise we're dealing with an alias to something that doesn't need to be specialized, or
|
||||
// whose usages will already be specialized in the rest of the program. Let's just build the
|
||||
// rest of the program now to get our hole.
|
||||
let mut result = build_rest(env, procs, layout_cache);
|
||||
if procs.is_imported_module_thunk(right) {
|
||||
// if this is an imported symbol, then we must make sure it is
|
||||
// specialized, and wrap the original in a function pointer.
|
||||
|
@ -6392,6 +6526,7 @@ fn let_empty_struct<'a>(assigned: Symbol, hole: &'a Stmt<'a>) -> Stmt<'a> {
|
|||
}
|
||||
|
||||
/// If the symbol is a function, make sure it is properly specialized
|
||||
// TODO: rename this now that we handle polymorphic non-function expressions too
|
||||
fn reuse_function_symbol<'a>(
|
||||
env: &mut Env<'a, '_>,
|
||||
procs: &mut Procs<'a>,
|
||||
|
@ -6401,6 +6536,35 @@ fn reuse_function_symbol<'a>(
|
|||
result: Stmt<'a>,
|
||||
original: Symbol,
|
||||
) -> Stmt<'a> {
|
||||
if let Some((expr, expr_var)) = procs.partial_exprs.get(original) {
|
||||
// Specialize the expression type now, based off the `arg_var` we've been given.
|
||||
// TODO: cache the specialized result
|
||||
let snapshot = env.subs.snapshot();
|
||||
let cache_snapshot = layout_cache.snapshot();
|
||||
let _unified = roc_unify::unify::unify(
|
||||
env.subs,
|
||||
arg_var.unwrap(),
|
||||
expr_var,
|
||||
roc_unify::unify::Mode::Eq,
|
||||
);
|
||||
|
||||
let result = with_hole(
|
||||
env,
|
||||
expr.clone(),
|
||||
expr_var,
|
||||
procs,
|
||||
layout_cache,
|
||||
symbol,
|
||||
env.arena.alloc(result),
|
||||
);
|
||||
|
||||
// Restore the prior state so as not to interfere with future specializations.
|
||||
env.subs.rollback_to(snapshot);
|
||||
layout_cache.rollback_to(cache_snapshot);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
match procs.get_partial_proc(original) {
|
||||
None => {
|
||||
match arg_var {
|
||||
|
@ -6566,7 +6730,7 @@ fn assign_to_symbol<'a>(
|
|||
) -> Stmt<'a> {
|
||||
use ReuseSymbol::*;
|
||||
match can_reuse_symbol(env, procs, &loc_arg.value) {
|
||||
Imported(original) | LocalFunction(original) => {
|
||||
Imported(original) | LocalFunction(original) | UnspecializedExpr(original) => {
|
||||
// for functions we must make sure they are specialized correctly
|
||||
reuse_function_symbol(
|
||||
env,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue