Specialize polymorphic non-function expressions

This commit fixes a long-standing bug wherein bindings to polymorphic,
non-function expressions would be lowered at binding site, rather than
being specialized at the call site.

Concretely, consider the program

```
main =
    n = 1

    idU8 : U8 -> U8
    idU8 = \m -> m

    idU8 n
```

Prior to this commit, we would lower `n = 1` as part of the IR, and the
`n` at the call site `idU8 n` would reference the lowered definition.
However, at the definition site, `1` has the polymorphic type `Num *` -
it is not until the the call site that we are able to refine the type
bound by `n`, but at that point it's too late. Since the default layout
for `Num *` is a signed 64-bit int, we would generate IR like

```
procedure main():
    let App.n : Builtin(Int(I64)) = 1i64;
    ...
    let App.5 : Builtin(Int(U8)) = CallByName Add.idU8 App.n;
    ret App.5;
```

But we know `idU8` expects a `u8`; giving it an `i64` is nonsense.
Indeed this would trigger LLVM miscompilations later on.

To remedy this, we now keep a sidecar table that maps symbols to the
polymorphic expression they reference, when they do so. We then
specialize references to symbols on the fly at usage sites, similar to
how we specialize function usages.

Looking at our example, the definition `n = 1` is now never lowered to
the IR directly. We only generate code for `1` at each place `n` is
referenced. As a larger example, you can imagine that

```
main =
    n = 1

    asU8 : U8 -> U8
    asU32 : U32 -> U8

    asU8 n + asU32 n
```

is lowered to the moral equivalent of

```
main =
    asU8 : U8 -> U8
    asU32 : U32 -> U8

    asU8 1 + asU32 1
```

Moreover, transient usages of polymorphic expressions are lowered
successfully with this approach. See for example the
`monomorphized_tag_with_polymorphic_arg_and_monomorphic_arg` test in
this commit, which checks that

```
main =
    mono : U8
    mono = 15
    poly = A
    wrap = Wrapped poly mono

    useWrap1 : [Wrapped [A] U8, Other] -> U8
    useWrap1 =
        \w -> when w is
            Wrapped A n -> n
            Other -> 0

    useWrap2 : [Wrapped [A, B] U8] -> U8
    useWrap2 =
        \w -> when w is
            Wrapped A n -> n
            Wrapped B _ -> 0

    useWrap1 wrap * useWrap2 wrap
```

has proper code generated for it, in the presence of the polymorphic
`wrap` which references the polymorphic `poly`.

https://github.com/rtfeldman/roc/pull/2347 had a different approach to
this - polymorphic expressions would be converted to (possibly capturing) thunks.
This has the benefit of reducing code size if there are many polymorphic
usages, but may make the generated code slower and makes integration
with the existing IR implementation harder. In practice I think the
average number of polymorphic usages of an expression will be very
small.

Closes https://github.com/rtfeldman/roc/issues/2336
Closes https://github.com/rtfeldman/roc/issues/2254
Closes https://github.com/rtfeldman/roc/issues/2344
This commit is contained in:
ayazhafiz 2022-01-19 22:52:15 -05:00
parent 3342090e7b
commit a5de224626
23 changed files with 655 additions and 68 deletions

View file

@ -207,6 +207,67 @@ impl<'a> PartialProc<'a> {
}
}
#[derive(Clone, Debug)]
enum PartialExprLink {
Aliases(Symbol),
Expr(roc_can::expr::Expr, Variable),
}
/// A table of symbols to polymorphic expressions. For example, in the program
///
/// n = 1
///
/// asU8 : U8 -> U8
/// asU8 = \_ -> 1
///
/// asU32 : U32 -> U8
/// asU32 = \_ -> 1
///
/// asU8 n + asU32 n
///
/// The expression bound by `n` doesn't have a definite layout until it is used
/// at the call sites `asU8 n`, `asU32 n`.
///
/// Polymorphic *functions* are stored in `PartialProc`s, since functions are
/// non longer first-class once we finish lowering to the IR.
#[derive(Clone, Debug)]
pub struct PartialExprs(BumpMap<Symbol, PartialExprLink>);
impl PartialExprs {
pub fn new_in<'a>(arena: &'a Bump) -> Self {
Self(BumpMap::new_in(arena))
}
pub fn insert(&mut self, symbol: Symbol, expr: roc_can::expr::Expr, expr_var: Variable) {
self.0.insert(symbol, PartialExprLink::Expr(expr, expr_var));
}
pub fn insert_alias(&mut self, symbol: Symbol, aliases: Symbol) {
self.0.insert(symbol, PartialExprLink::Aliases(aliases));
}
pub fn contains(&self, symbol: Symbol) -> bool {
self.0.contains_key(&symbol)
}
pub fn get(&mut self, mut symbol: Symbol) -> Option<(&roc_can::expr::Expr, Variable)> {
// In practice the alias chain is very short
loop {
match self.0.get(&symbol) {
None => {
return None;
}
Some(&PartialExprLink::Aliases(real_symbol)) => {
symbol = real_symbol;
}
Some(PartialExprLink::Expr(expr, var)) => {
return Some((expr, *var));
}
}
}
}
}
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum CapturedSymbols<'a> {
None,
@ -668,6 +729,7 @@ impl<'a> Specialized<'a> {
#[derive(Clone, Debug)]
pub struct Procs<'a> {
pub partial_procs: PartialProcs<'a>,
pub partial_exprs: PartialExprs,
pub imported_module_thunks: &'a [Symbol],
pub module_thunks: &'a [Symbol],
pending_specializations: PendingSpecializations<'a>,
@ -680,6 +742,7 @@ impl<'a> Procs<'a> {
pub fn new_in(arena: &'a Bump) -> Self {
Self {
partial_procs: PartialProcs::new_in(arena),
partial_exprs: PartialExprs::new_in(arena),
imported_module_thunks: &[],
module_thunks: &[],
pending_specializations: PendingSpecializations::Finding(Suspended::new_in(arena)),
@ -3103,16 +3166,20 @@ pub fn with_hole<'a>(
_ => {}
}
// continue with the default path
let mut stmt = with_hole(
env,
cont.value,
variable,
procs,
layout_cache,
assigned,
hole,
);
let build_rest =
|env: &mut Env<'a, '_>,
procs: &mut Procs<'a>,
layout_cache: &mut LayoutCache<'a>| {
with_hole(
env,
cont.value,
variable,
procs,
layout_cache,
assigned,
hole,
)
};
// a variable is aliased
if let roc_can::expr::Expr::Var(original) = def.loc_expr.value {
@ -3124,18 +3191,17 @@ pub fn with_hole<'a>(
//
// foo = RBTRee.empty
stmt = handle_variable_aliasing(
handle_variable_aliasing(
env,
procs,
layout_cache,
def.expr_var,
symbol,
original,
stmt,
);
stmt
build_rest,
)
} else {
let rest = build_rest(env, procs, layout_cache);
with_hole(
env,
def.loc_expr.value,
@ -3143,7 +3209,7 @@ pub fn with_hole<'a>(
procs,
layout_cache,
symbol,
env.arena.alloc(stmt),
env.arena.alloc(rest),
)
}
} else {
@ -3328,6 +3394,7 @@ pub fn with_hole<'a>(
let mut can_fields = Vec::with_capacity_in(fields.len(), env.arena);
enum Field {
// TODO: rename this since it can handle unspecialized expressions now too
Function(Symbol, Variable),
ValueSymbol,
Field(roc_can::expr::Field),
@ -3338,7 +3405,7 @@ pub fn with_hole<'a>(
use ReuseSymbol::*;
match fields.remove(&label) {
Some(field) => match can_reuse_symbol(env, procs, &field.loc_expr.value) {
Imported(symbol) | LocalFunction(symbol) => {
Imported(symbol) | LocalFunction(symbol) | UnspecializedExpr(symbol) => {
field_symbols.push(symbol);
can_fields.push(Field::Function(symbol, variable));
}
@ -4064,6 +4131,9 @@ pub fn with_hole<'a>(
LocalFunction(_) => {
unreachable!("if this was known to be a function, we would not be here")
}
UnspecializedExpr(_) => {
unreachable!("if this was known to be an unspecialized expression, we would not be here")
}
Imported(thunk_name) => {
debug_assert!(procs.is_imported_module_thunk(thunk_name));
@ -5023,6 +5093,31 @@ fn register_capturing_closure<'a>(
}
}
fn is_literal_like(expr: &roc_can::expr::Expr) -> bool {
use roc_can::expr::Expr::*;
matches!(
expr,
Num(..)
| Int(..)
| Float(..)
| List { .. }
| Str(_)
| ZeroArgumentTag { .. }
| Tag { .. }
| Record { .. }
)
}
fn expr_is_polymorphic<'a>(
env: &mut Env<'a, '_>,
expr: &roc_can::expr::Expr,
expr_var: Variable,
) -> bool {
// TODO: I don't think we need the `is_literal_like` check, but taking it slow for now...
let is_flex_or_rigid = |c: &Content| matches!(c, Content::FlexVar(_) | Content::RigidVar(_));
is_literal_like(expr) && env.subs.var_contains_content(expr_var, is_flex_or_rigid)
}
pub fn from_can<'a>(
env: &mut Env<'a, '_>,
variable: Variable,
@ -5177,19 +5272,26 @@ pub fn from_can<'a>(
// or
//
// foo = RBTRee.empty
let mut rest = from_can(env, def.expr_var, cont.value, procs, layout_cache);
rest = handle_variable_aliasing(
// TODO: right now we need help out rustc with the closure types;
// it isn't able to infer the right lifetime bounds. See if we
// can remove the annotations in the future.
let build_rest =
|env: &mut Env<'a, '_>,
procs: &mut Procs<'a>,
layout_cache: &mut LayoutCache<'a>| {
from_can(env, def.expr_var, cont.value, procs, layout_cache)
};
return handle_variable_aliasing(
env,
procs,
layout_cache,
def.expr_var,
*symbol,
original,
rest,
build_rest,
);
return rest;
}
roc_can::expr::Expr::LetNonRec(nested_def, nested_cont, nested_annotation) => {
use roc_can::expr::Expr::*;
@ -5273,6 +5375,21 @@ pub fn from_can<'a>(
return from_can(env, variable, new_outer, procs, layout_cache);
}
ref body if expr_is_polymorphic(env, body, def.expr_var) => {
// This is a pattern like
//
// n = 1
// asU8 n
//
// At the definition site `n = 1` we only know `1` to have the type `[Int *]`,
// which won't be refined until the call `asU8 n`. Add it as a partial expression
// that will be specialized at each concrete usage site.
procs
.partial_exprs
.insert(*symbol, def.loc_expr.value, def.expr_var);
return from_can(env, variable, cont.value, procs, layout_cache);
}
_ => {
let rest = from_can(env, variable, cont.value, procs, layout_cache);
return with_hole(
@ -6290,6 +6407,7 @@ enum ReuseSymbol {
Imported(Symbol),
LocalFunction(Symbol),
Value(Symbol),
UnspecializedExpr(Symbol),
NotASymbol,
}
@ -6307,6 +6425,8 @@ fn can_reuse_symbol<'a>(
Imported(symbol)
} else if procs.partial_procs.contains_key(symbol) {
LocalFunction(symbol)
} else if procs.partial_exprs.contains(symbol) {
UnspecializedExpr(symbol)
} else {
Value(symbol)
}
@ -6326,15 +6446,29 @@ fn possible_reuse_symbol<'a>(
}
}
fn handle_variable_aliasing<'a>(
fn handle_variable_aliasing<'a, BuildRest>(
env: &mut Env<'a, '_>,
procs: &mut Procs<'a>,
layout_cache: &mut LayoutCache<'a>,
variable: Variable,
left: Symbol,
right: Symbol,
mut result: Stmt<'a>,
) -> Stmt<'a> {
build_rest: BuildRest,
) -> Stmt<'a>
where
BuildRest: FnOnce(&mut Env<'a, '_>, &mut Procs<'a>, &mut LayoutCache<'a>) -> Stmt<'a>,
{
if procs.partial_exprs.contains(right) {
// If `right` links to a partial expression, make sure we link `left` to it as well, so
// that usages of it will be specialized when building the rest of the program.
procs.partial_exprs.insert_alias(left, right);
return build_rest(env, procs, layout_cache);
}
// Otherwise we're dealing with an alias to something that doesn't need to be specialized, or
// whose usages will already be specialized in the rest of the program. Let's just build the
// rest of the program now to get our hole.
let mut result = build_rest(env, procs, layout_cache);
if procs.is_imported_module_thunk(right) {
// if this is an imported symbol, then we must make sure it is
// specialized, and wrap the original in a function pointer.
@ -6392,6 +6526,7 @@ fn let_empty_struct<'a>(assigned: Symbol, hole: &'a Stmt<'a>) -> Stmt<'a> {
}
/// If the symbol is a function, make sure it is properly specialized
// TODO: rename this now that we handle polymorphic non-function expressions too
fn reuse_function_symbol<'a>(
env: &mut Env<'a, '_>,
procs: &mut Procs<'a>,
@ -6401,6 +6536,35 @@ fn reuse_function_symbol<'a>(
result: Stmt<'a>,
original: Symbol,
) -> Stmt<'a> {
if let Some((expr, expr_var)) = procs.partial_exprs.get(original) {
// Specialize the expression type now, based off the `arg_var` we've been given.
// TODO: cache the specialized result
let snapshot = env.subs.snapshot();
let cache_snapshot = layout_cache.snapshot();
let _unified = roc_unify::unify::unify(
env.subs,
arg_var.unwrap(),
expr_var,
roc_unify::unify::Mode::Eq,
);
let result = with_hole(
env,
expr.clone(),
expr_var,
procs,
layout_cache,
symbol,
env.arena.alloc(result),
);
// Restore the prior state so as not to interfere with future specializations.
env.subs.rollback_to(snapshot);
layout_cache.rollback_to(cache_snapshot);
return result;
}
match procs.get_partial_proc(original) {
None => {
match arg_var {
@ -6566,7 +6730,7 @@ fn assign_to_symbol<'a>(
) -> Stmt<'a> {
use ReuseSymbol::*;
match can_reuse_symbol(env, procs, &loc_arg.value) {
Imported(original) | LocalFunction(original) => {
Imported(original) | LocalFunction(original) | UnspecializedExpr(original) => {
// for functions we must make sure they are specialized correctly
reuse_function_symbol(
env,