Merge pull request #3738 from rtfeldman/i3444

Layout generation for recursive lambda sets
This commit is contained in:
Folkert de Vries 2022-08-11 10:22:07 +02:00 committed by GitHub
commit ae0e90c8f3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 441 additions and 110 deletions

View file

@ -3130,7 +3130,12 @@ fn specialize_external<'a>(
tag_id,
..
} => {
debug_assert!(matches!(union_layout, UnionLayout::NonRecursive(_)));
debug_assert!(matches!(
union_layout,
UnionLayout::NonRecursive(_)
| UnionLayout::Recursive(_)
| UnionLayout::NullableUnwrapped { .. }
));
debug_assert_eq!(field_layouts.len(), captured.len());
// captured variables are in symbol-alphabetic order, but now we want
@ -3149,7 +3154,9 @@ fn specialize_external<'a>(
size2.cmp(&size1)
});
for (index, (symbol, layout)) in combined.iter().enumerate() {
for (index, (symbol, _)) in combined.iter().enumerate() {
let layout = union_layout.layout_at(tag_id, index);
let expr = Expr::UnionAtIndex {
tag_id,
structure: Symbol::ARG_CLOSURE,
@ -3162,7 +3169,7 @@ fn specialize_external<'a>(
specialized_body = Stmt::Let(
symbol,
expr,
**layout,
layout,
env.arena.alloc(specialized_body),
);
}

View file

@ -10,7 +10,8 @@ use roc_problem::can::RuntimeError;
use roc_target::{PtrWidth, TargetInfo};
use roc_types::num::NumericRange;
use roc_types::subs::{
self, Content, FlatType, Label, RecordFields, Subs, UnionTags, UnsortedUnionLabels, Variable,
self, Content, FlatType, Label, OptVariable, RecordFields, Subs, UnionTags,
UnsortedUnionLabels, Variable,
};
use roc_types::types::{gather_fields_unsorted_iter, RecordField, RecordFieldsError};
use std::cmp::Ordering;
@ -842,6 +843,7 @@ impl<'a> LambdaSet<'a> {
let comparator = |other_name: Symbol, other_captures_layouts: &[Layout]| {
other_name == lambda_name.name
// Make sure all captures are equal
&& other_captures_layouts
.iter()
.eq(lambda_name.captures_niche.0)
@ -859,14 +861,25 @@ impl<'a> LambdaSet<'a> {
debug_assert!(self.contains(function_symbol), "function symbol not in set");
let comparator = |other_name: Symbol, other_captures_layouts: &[Layout]| {
other_name == function_symbol && other_captures_layouts.iter().eq(captures_layouts)
other_name == function_symbol
&& other_captures_layouts
.iter()
.zip(captures_layouts)
.all(|(other_layout, layout)| self.capture_layouts_eq(other_layout, layout))
};
let (name, layouts) = self
.set
.iter()
.find(|(name, layouts)| comparator(*name, layouts))
.expect("no lambda set found");
.unwrap_or_else(|| {
internal_error!(
"no lambda set found for ({:?}, {:#?}): {:#?}",
function_symbol,
captures_layouts,
self
)
});
LambdaName {
name: *name,
@ -874,6 +887,38 @@ impl<'a> LambdaSet<'a> {
}
}
/// Checks if two captured layouts are equivalent under the current lambda set.
/// Resolves recursive pointers to the layout of the lambda set.
fn capture_layouts_eq(&self, left: &Layout, right: &Layout) -> bool {
if left == right {
return true;
}
let left = if left == &Layout::RecursivePointer {
let runtime_repr = self.runtime_representation();
debug_assert!(matches!(
runtime_repr,
Layout::Union(UnionLayout::Recursive(_) | UnionLayout::NullableUnwrapped { .. })
));
Layout::LambdaSet(*self)
} else {
*left
};
let right = if right == &Layout::RecursivePointer {
let runtime_repr = self.runtime_representation();
debug_assert!(matches!(
runtime_repr,
Layout::Union(UnionLayout::Recursive(_) | UnionLayout::NullableUnwrapped { .. })
));
Layout::LambdaSet(*self)
} else {
*right
};
left == right
}
fn layout_for_member<F>(&self, comparator: F) -> ClosureRepresentation<'a>
where
F: Fn(Symbol, &[Layout]) -> bool,
@ -902,16 +947,48 @@ impl<'a> LambdaSet<'a> {
union_layout: *union,
}
}
UnionLayout::Recursive(_) => todo!("recursive closures"),
UnionLayout::Recursive(_) => {
let (index, (name, fields)) = self
.set
.iter()
.enumerate()
.find(|(_, (s, layouts))| comparator(*s, layouts))
.unwrap();
let closure_name = *name;
ClosureRepresentation::Union {
tag_id: index as TagIdIntType,
alphabetic_order_fields: fields,
closure_name,
union_layout: *union,
}
}
UnionLayout::NullableUnwrapped {
nullable_id: _,
other_fields: _,
} => {
let (index, (name, fields)) = self
.set
.iter()
.enumerate()
.find(|(_, (s, layouts))| comparator(*s, layouts))
.unwrap();
let closure_name = *name;
ClosureRepresentation::Union {
tag_id: index as TagIdIntType,
alphabetic_order_fields: fields,
closure_name,
union_layout: *union,
}
}
UnionLayout::NonNullableUnwrapped(_) => todo!("recursive closures"),
UnionLayout::NullableWrapped {
nullable_id: _,
other_tags: _,
} => todo!("recursive closures"),
UnionLayout::NullableUnwrapped {
nullable_id: _,
other_fields: _,
} => todo!("recursive closures"),
}
}
Layout::Struct { .. } => {
@ -971,7 +1048,7 @@ impl<'a> LambdaSet<'a> {
target_info: TargetInfo,
) -> Result<Self, LayoutProblem> {
match resolve_lambda_set(subs, closure_var) {
ResolvedLambdaSet::Set(mut lambdas) => {
ResolvedLambdaSet::Set(mut lambdas, opt_recursion_var) => {
// sort the tags; make sure ordering stays intact!
lambdas.sort_by_key(|(sym, _)| *sym);
@ -992,6 +1069,9 @@ impl<'a> LambdaSet<'a> {
seen: Vec::new_in(arena),
target_info,
};
if let Some(rec_var) = opt_recursion_var.into_variable() {
env.insert_seen(rec_var);
}
for var in variables {
arguments.push(Layout::from_var(&mut env, *var)?);
@ -1048,6 +1128,7 @@ impl<'a> LambdaSet<'a> {
arena,
subs,
set_with_variables,
opt_recursion_var.into_variable(),
target_info,
));
@ -1071,10 +1152,28 @@ impl<'a> LambdaSet<'a> {
arena: &'a Bump,
subs: &Subs,
tags: std::vec::Vec<(Symbol, std::vec::Vec<Variable>)>,
opt_rec_var: Option<Variable>,
target_info: TargetInfo,
) -> Layout<'a> {
if let Some(rec_var) = opt_rec_var {
let tags: std::vec::Vec<_> = tags
.iter()
.map(|(sym, vars)| (sym, vars.as_slice()))
.collect();
let tags = UnsortedUnionLabels { tags };
let mut env = Env {
seen: Vec::new_in(arena),
target_info,
arena,
subs,
};
return layout_from_recursive_union(&mut env, rec_var, &tags)
.expect("unable to create lambda set representation");
}
// otherwise, this is a closure with a payload
let variant = union_sorted_tags_help(arena, tags, None, subs, target_info);
let variant = union_sorted_tags_help(arena, tags, opt_rec_var, subs, target_info);
use UnionVariant::*;
match variant {
@ -1108,7 +1207,12 @@ impl<'a> LambdaSet<'a> {
Layout::Union(UnionLayout::NonRecursive(tag_arguments.into_bump_slice()))
}
_ => panic!("handle recursive layouts"),
Recursive { .. }
| NullableUnwrapped { .. }
| NullableWrapped { .. }
| NonNullableUnwrapped { .. } => {
internal_error!("Recursive layouts should be produced in an earlier branch")
}
}
}
}
@ -1130,7 +1234,10 @@ impl<'a> LambdaSet<'a> {
}
enum ResolvedLambdaSet {
Set(std::vec::Vec<(Symbol, std::vec::Vec<Variable>)>),
Set(
std::vec::Vec<(Symbol, std::vec::Vec<Variable>)>,
OptVariable,
),
/// TODO: figure out if this can happen in a correct program, or is the result of a bug in our
/// compiler. See https://github.com/rtfeldman/roc/issues/3163.
Unbound,
@ -1142,7 +1249,7 @@ fn resolve_lambda_set(subs: &Subs, mut var: Variable) -> ResolvedLambdaSet {
match subs.get_content_without_compacting(var) {
Content::LambdaSet(subs::LambdaSet {
solved,
recursion_var: _,
recursion_var,
unspecialized,
ambient_function: _,
}) => {
@ -1153,7 +1260,7 @@ fn resolve_lambda_set(subs: &Subs, mut var: Variable) -> ResolvedLambdaSet {
subs.uls_of_var
);
roc_types::pretty_print::push_union(subs, solved, &mut set);
return ResolvedLambdaSet::Set(set);
return ResolvedLambdaSet::Set(set, *recursion_var);
}
Content::RecursionVar { structure, .. } => {
var = *structure;
@ -2130,10 +2237,14 @@ fn layout_from_flat_type<'a>(
}
}
Func(_, closure_var, _) => {
let lambda_set =
LambdaSet::from_var(env.arena, env.subs, closure_var, env.target_info)?;
if env.is_seen(closure_var) {
Ok(Layout::RecursivePointer)
} else {
let lambda_set =
LambdaSet::from_var(env.arena, env.subs, closure_var, env.target_info)?;
Ok(Layout::LambdaSet(lambda_set))
Ok(Layout::LambdaSet(lambda_set))
}
}
Record(fields, ext_var) => {
// extract any values from the ext_var

View file

@ -713,12 +713,13 @@ fn solve(
new_env.insert_symbol_var_if_vacant(*symbol, loc_var.value);
}
stack.push(Work::CheckForInfiniteTypes(local_def_vars));
stack.push(Work::Constraint {
env: arena.alloc(new_env),
rank,
constraint: ret_constraint,
});
// Check for infinite types first
stack.push(Work::CheckForInfiniteTypes(local_def_vars));
continue;
}
@ -831,12 +832,13 @@ fn solve(
// Now solve the body, using the new vars_by_symbol which includes
// the assignments' name-to-variable mappings.
stack.push(Work::CheckForInfiniteTypes(local_def_vars));
stack.push(Work::Constraint {
env: arena.alloc(new_env),
rank,
constraint: ret_constraint,
});
// Check for infinite types first
stack.push(Work::CheckForInfiniteTypes(local_def_vars));
state = state_for_ret_con;
@ -2874,28 +2876,34 @@ fn check_for_infinite_type(
) {
let var = loc_var.value;
while let Err((recursive, _chain)) = subs.occurs(var) {
// try to make a union recursive, see if that helps
match subs.get_content_without_compacting(recursive) {
&Content::Structure(FlatType::TagUnion(tags, ext_var)) => {
subs.mark_tag_union_recursive(recursive, tags, ext_var);
}
&Content::LambdaSet(subs::LambdaSet {
solved,
recursion_var: _,
unspecialized,
ambient_function: ambient_function_var,
}) => {
subs.mark_lambda_set_recursive(
recursive,
'next_occurs_check: while let Err((_, chain)) = subs.occurs(var) {
// walk the chain till we find a tag union or lambda set, starting from the variable that
// occurred recursively, which is always at the end of the chain.
for &var in chain.iter().rev() {
match *subs.get_content_without_compacting(var) {
Content::Structure(FlatType::TagUnion(tags, ext_var)) => {
subs.mark_tag_union_recursive(var, tags, ext_var);
continue 'next_occurs_check;
}
Content::LambdaSet(subs::LambdaSet {
solved,
recursion_var: _,
unspecialized,
ambient_function_var,
);
ambient_function: ambient_function_var,
}) => {
subs.mark_lambda_set_recursive(
var,
solved,
unspecialized,
ambient_function_var,
);
continue 'next_occurs_check;
}
_ => { /* fall through */ }
}
_other => circular_error(subs, problems, symbol, &loc_var),
}
circular_error(subs, problems, symbol, &loc_var);
}
}

View file

@ -7657,4 +7657,39 @@ mod solve_expr {
"Num *",
);
}
#[test]
fn issue_3444() {
infer_queries!(
indoc!(
r#"
compose = \f, g ->
closCompose = \x -> g (f x)
closCompose
const = \x ->
closConst = \_ -> x
closConst
list = []
res : Str -> Str
res = List.walk list (const "z") (\c1, c2 -> compose c1 c2)
# ^^^^^ ^^^^^^^
# ^^^^^^^^^^^^^^^^^^^^^^^^
#^^^{-1}
res "hello"
#^^^{-1}
"#
),
@r###"
const : Str -[[const(2)]]-> (Str -[[closCompose(7) (Str -a-> Str) (Str -[[]]-> Str), closConst(10) Str] as a]-> Str)
compose : (Str -a-> Str), (Str -[[]]-> Str) -[[compose(1)]]-> (Str -a-> Str)
\c1, c2 -> compose c1 c2 : (Str -a-> Str), (Str -[[]]-> Str) -[[11(11)]]-> (Str -a-> Str)
res : Str -[[closCompose(7) (Str -a-> Str) (Str -[[]]-> Str), closConst(10) Str] as a]-> Str
res : Str -[[closCompose(7) (Str -a-> Str) (Str -[[]]-> Str), closConst(10) Str] as a]-> Str
"###
);
}
}

View file

@ -3705,3 +3705,136 @@ fn runtime_error_when_degenerate_pattern_reached() {
true // allow errors
);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn recursive_lambda_set_issue_3444() {
assert_evals_to!(
indoc!(
r#"
combine = \f, g -> \x -> g (f x)
const = \x -> (\_y -> x)
list = [const "a", const "b", const "c"]
res : Str -> Str
res = List.walk list (const "z") (\c1, c2 -> combine c1 c2)
res "hello"
"#
),
RocStr::from("c"),
RocStr
);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn recursive_lambda_set_toplevel_issue_3444() {
assert_evals_to!(
indoc!(
r#"
app "test" provides [main] to "./platform"
combine = \f, g -> \x -> g (f x)
const = \x -> (\_y -> x)
list = [const "a", const "b", const "c"]
res : Str -> Str
res = List.walk list (const "z") (\c1, c2 -> combine c1 c2)
main = res "hello"
"#
),
RocStr::from("c"),
RocStr
);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn recursive_lambda_set_issue_3444_inferred() {
assert_evals_to!(
indoc!(
r#"
combine = \f, g -> \x -> g (f x)
const = \x -> (\_y -> x)
list = [const "a", const "b", const "c"]
res = List.walk list (const "z") (\c1, c2 -> combine c1 c2)
res "hello"
"#
),
RocStr::from("c"),
RocStr
);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn compose_recursive_lambda_set_productive_toplevel() {
assert_evals_to!(
indoc!(
r#"
app "test" provides [main] to "./platform"
compose = \f, g -> \x -> g (f x)
identity = \x -> x
exclaim = \s -> "\(s)!"
whisper = \s -> "(\(s))"
main =
res: Str -> Str
res = List.walk [ exclaim, whisper ] identity compose
res "hello"
"#
),
RocStr::from("(hello!)"),
RocStr
);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn compose_recursive_lambda_set_productive_nested() {
assert_evals_to!(
indoc!(
r#"
compose = \f, g -> \x -> g (f x)
identity = \x -> x
exclaim = \s -> "\(s)!"
whisper = \s -> "(\(s))"
res: Str -> Str
res = List.walk [ exclaim, whisper ] identity compose
res "hello"
"#
),
RocStr::from("(hello!)"),
RocStr
);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn compose_recursive_lambda_set_productive_inferred() {
assert_evals_to!(
indoc!(
r#"
compose = \f, g -> \x -> g (f x)
identity = \x -> x
exclaim = \s -> "\(s)!"
whisper = \s -> "(\(s))"
res = List.walk [ exclaim, whisper ] identity compose
res "hello"
"#
),
RocStr::from("(hello!)"),
RocStr
);
}

View file

@ -2518,10 +2518,12 @@ fn unify_flat_type<M: MetaCollector>(
outcome.union(arg_outcome);
if outcome.mismatches.is_empty() {
let merged_closure_var = choose_merged_var(env.subs, *l_closure, *r_closure);
outcome.union(merge(
env,
ctx,
Structure(Func(*r_args, *r_closure, *r_ret)),
Structure(Func(*r_args, merged_closure_var, *r_ret)),
));
}