mirror of
https://github.com/roc-lang/roc.git
synced 2025-07-24 06:55:15 +00:00
Merge pull request #3738 from rtfeldman/i3444
Layout generation for recursive lambda sets
This commit is contained in:
commit
ae0e90c8f3
8 changed files with 441 additions and 110 deletions
|
@ -3130,7 +3130,12 @@ fn specialize_external<'a>(
|
|||
tag_id,
|
||||
..
|
||||
} => {
|
||||
debug_assert!(matches!(union_layout, UnionLayout::NonRecursive(_)));
|
||||
debug_assert!(matches!(
|
||||
union_layout,
|
||||
UnionLayout::NonRecursive(_)
|
||||
| UnionLayout::Recursive(_)
|
||||
| UnionLayout::NullableUnwrapped { .. }
|
||||
));
|
||||
debug_assert_eq!(field_layouts.len(), captured.len());
|
||||
|
||||
// captured variables are in symbol-alphabetic order, but now we want
|
||||
|
@ -3149,7 +3154,9 @@ fn specialize_external<'a>(
|
|||
size2.cmp(&size1)
|
||||
});
|
||||
|
||||
for (index, (symbol, layout)) in combined.iter().enumerate() {
|
||||
for (index, (symbol, _)) in combined.iter().enumerate() {
|
||||
let layout = union_layout.layout_at(tag_id, index);
|
||||
|
||||
let expr = Expr::UnionAtIndex {
|
||||
tag_id,
|
||||
structure: Symbol::ARG_CLOSURE,
|
||||
|
@ -3162,7 +3169,7 @@ fn specialize_external<'a>(
|
|||
specialized_body = Stmt::Let(
|
||||
symbol,
|
||||
expr,
|
||||
**layout,
|
||||
layout,
|
||||
env.arena.alloc(specialized_body),
|
||||
);
|
||||
}
|
||||
|
|
|
@ -10,7 +10,8 @@ use roc_problem::can::RuntimeError;
|
|||
use roc_target::{PtrWidth, TargetInfo};
|
||||
use roc_types::num::NumericRange;
|
||||
use roc_types::subs::{
|
||||
self, Content, FlatType, Label, RecordFields, Subs, UnionTags, UnsortedUnionLabels, Variable,
|
||||
self, Content, FlatType, Label, OptVariable, RecordFields, Subs, UnionTags,
|
||||
UnsortedUnionLabels, Variable,
|
||||
};
|
||||
use roc_types::types::{gather_fields_unsorted_iter, RecordField, RecordFieldsError};
|
||||
use std::cmp::Ordering;
|
||||
|
@ -842,6 +843,7 @@ impl<'a> LambdaSet<'a> {
|
|||
|
||||
let comparator = |other_name: Symbol, other_captures_layouts: &[Layout]| {
|
||||
other_name == lambda_name.name
|
||||
// Make sure all captures are equal
|
||||
&& other_captures_layouts
|
||||
.iter()
|
||||
.eq(lambda_name.captures_niche.0)
|
||||
|
@ -859,14 +861,25 @@ impl<'a> LambdaSet<'a> {
|
|||
debug_assert!(self.contains(function_symbol), "function symbol not in set");
|
||||
|
||||
let comparator = |other_name: Symbol, other_captures_layouts: &[Layout]| {
|
||||
other_name == function_symbol && other_captures_layouts.iter().eq(captures_layouts)
|
||||
other_name == function_symbol
|
||||
&& other_captures_layouts
|
||||
.iter()
|
||||
.zip(captures_layouts)
|
||||
.all(|(other_layout, layout)| self.capture_layouts_eq(other_layout, layout))
|
||||
};
|
||||
|
||||
let (name, layouts) = self
|
||||
.set
|
||||
.iter()
|
||||
.find(|(name, layouts)| comparator(*name, layouts))
|
||||
.expect("no lambda set found");
|
||||
.unwrap_or_else(|| {
|
||||
internal_error!(
|
||||
"no lambda set found for ({:?}, {:#?}): {:#?}",
|
||||
function_symbol,
|
||||
captures_layouts,
|
||||
self
|
||||
)
|
||||
});
|
||||
|
||||
LambdaName {
|
||||
name: *name,
|
||||
|
@ -874,6 +887,38 @@ impl<'a> LambdaSet<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Checks if two captured layouts are equivalent under the current lambda set.
|
||||
/// Resolves recursive pointers to the layout of the lambda set.
|
||||
fn capture_layouts_eq(&self, left: &Layout, right: &Layout) -> bool {
|
||||
if left == right {
|
||||
return true;
|
||||
}
|
||||
|
||||
let left = if left == &Layout::RecursivePointer {
|
||||
let runtime_repr = self.runtime_representation();
|
||||
debug_assert!(matches!(
|
||||
runtime_repr,
|
||||
Layout::Union(UnionLayout::Recursive(_) | UnionLayout::NullableUnwrapped { .. })
|
||||
));
|
||||
Layout::LambdaSet(*self)
|
||||
} else {
|
||||
*left
|
||||
};
|
||||
|
||||
let right = if right == &Layout::RecursivePointer {
|
||||
let runtime_repr = self.runtime_representation();
|
||||
debug_assert!(matches!(
|
||||
runtime_repr,
|
||||
Layout::Union(UnionLayout::Recursive(_) | UnionLayout::NullableUnwrapped { .. })
|
||||
));
|
||||
Layout::LambdaSet(*self)
|
||||
} else {
|
||||
*right
|
||||
};
|
||||
|
||||
left == right
|
||||
}
|
||||
|
||||
fn layout_for_member<F>(&self, comparator: F) -> ClosureRepresentation<'a>
|
||||
where
|
||||
F: Fn(Symbol, &[Layout]) -> bool,
|
||||
|
@ -902,16 +947,48 @@ impl<'a> LambdaSet<'a> {
|
|||
union_layout: *union,
|
||||
}
|
||||
}
|
||||
UnionLayout::Recursive(_) => todo!("recursive closures"),
|
||||
UnionLayout::Recursive(_) => {
|
||||
let (index, (name, fields)) = self
|
||||
.set
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, (s, layouts))| comparator(*s, layouts))
|
||||
.unwrap();
|
||||
|
||||
let closure_name = *name;
|
||||
|
||||
ClosureRepresentation::Union {
|
||||
tag_id: index as TagIdIntType,
|
||||
alphabetic_order_fields: fields,
|
||||
closure_name,
|
||||
union_layout: *union,
|
||||
}
|
||||
}
|
||||
UnionLayout::NullableUnwrapped {
|
||||
nullable_id: _,
|
||||
other_fields: _,
|
||||
} => {
|
||||
let (index, (name, fields)) = self
|
||||
.set
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, (s, layouts))| comparator(*s, layouts))
|
||||
.unwrap();
|
||||
|
||||
let closure_name = *name;
|
||||
|
||||
ClosureRepresentation::Union {
|
||||
tag_id: index as TagIdIntType,
|
||||
alphabetic_order_fields: fields,
|
||||
closure_name,
|
||||
union_layout: *union,
|
||||
}
|
||||
}
|
||||
UnionLayout::NonNullableUnwrapped(_) => todo!("recursive closures"),
|
||||
UnionLayout::NullableWrapped {
|
||||
nullable_id: _,
|
||||
other_tags: _,
|
||||
} => todo!("recursive closures"),
|
||||
UnionLayout::NullableUnwrapped {
|
||||
nullable_id: _,
|
||||
other_fields: _,
|
||||
} => todo!("recursive closures"),
|
||||
}
|
||||
}
|
||||
Layout::Struct { .. } => {
|
||||
|
@ -971,7 +1048,7 @@ impl<'a> LambdaSet<'a> {
|
|||
target_info: TargetInfo,
|
||||
) -> Result<Self, LayoutProblem> {
|
||||
match resolve_lambda_set(subs, closure_var) {
|
||||
ResolvedLambdaSet::Set(mut lambdas) => {
|
||||
ResolvedLambdaSet::Set(mut lambdas, opt_recursion_var) => {
|
||||
// sort the tags; make sure ordering stays intact!
|
||||
lambdas.sort_by_key(|(sym, _)| *sym);
|
||||
|
||||
|
@ -992,6 +1069,9 @@ impl<'a> LambdaSet<'a> {
|
|||
seen: Vec::new_in(arena),
|
||||
target_info,
|
||||
};
|
||||
if let Some(rec_var) = opt_recursion_var.into_variable() {
|
||||
env.insert_seen(rec_var);
|
||||
}
|
||||
|
||||
for var in variables {
|
||||
arguments.push(Layout::from_var(&mut env, *var)?);
|
||||
|
@ -1048,6 +1128,7 @@ impl<'a> LambdaSet<'a> {
|
|||
arena,
|
||||
subs,
|
||||
set_with_variables,
|
||||
opt_recursion_var.into_variable(),
|
||||
target_info,
|
||||
));
|
||||
|
||||
|
@ -1071,10 +1152,28 @@ impl<'a> LambdaSet<'a> {
|
|||
arena: &'a Bump,
|
||||
subs: &Subs,
|
||||
tags: std::vec::Vec<(Symbol, std::vec::Vec<Variable>)>,
|
||||
opt_rec_var: Option<Variable>,
|
||||
target_info: TargetInfo,
|
||||
) -> Layout<'a> {
|
||||
if let Some(rec_var) = opt_rec_var {
|
||||
let tags: std::vec::Vec<_> = tags
|
||||
.iter()
|
||||
.map(|(sym, vars)| (sym, vars.as_slice()))
|
||||
.collect();
|
||||
let tags = UnsortedUnionLabels { tags };
|
||||
let mut env = Env {
|
||||
seen: Vec::new_in(arena),
|
||||
target_info,
|
||||
arena,
|
||||
subs,
|
||||
};
|
||||
|
||||
return layout_from_recursive_union(&mut env, rec_var, &tags)
|
||||
.expect("unable to create lambda set representation");
|
||||
}
|
||||
|
||||
// otherwise, this is a closure with a payload
|
||||
let variant = union_sorted_tags_help(arena, tags, None, subs, target_info);
|
||||
let variant = union_sorted_tags_help(arena, tags, opt_rec_var, subs, target_info);
|
||||
|
||||
use UnionVariant::*;
|
||||
match variant {
|
||||
|
@ -1108,7 +1207,12 @@ impl<'a> LambdaSet<'a> {
|
|||
Layout::Union(UnionLayout::NonRecursive(tag_arguments.into_bump_slice()))
|
||||
}
|
||||
|
||||
_ => panic!("handle recursive layouts"),
|
||||
Recursive { .. }
|
||||
| NullableUnwrapped { .. }
|
||||
| NullableWrapped { .. }
|
||||
| NonNullableUnwrapped { .. } => {
|
||||
internal_error!("Recursive layouts should be produced in an earlier branch")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1130,7 +1234,10 @@ impl<'a> LambdaSet<'a> {
|
|||
}
|
||||
|
||||
enum ResolvedLambdaSet {
|
||||
Set(std::vec::Vec<(Symbol, std::vec::Vec<Variable>)>),
|
||||
Set(
|
||||
std::vec::Vec<(Symbol, std::vec::Vec<Variable>)>,
|
||||
OptVariable,
|
||||
),
|
||||
/// TODO: figure out if this can happen in a correct program, or is the result of a bug in our
|
||||
/// compiler. See https://github.com/rtfeldman/roc/issues/3163.
|
||||
Unbound,
|
||||
|
@ -1142,7 +1249,7 @@ fn resolve_lambda_set(subs: &Subs, mut var: Variable) -> ResolvedLambdaSet {
|
|||
match subs.get_content_without_compacting(var) {
|
||||
Content::LambdaSet(subs::LambdaSet {
|
||||
solved,
|
||||
recursion_var: _,
|
||||
recursion_var,
|
||||
unspecialized,
|
||||
ambient_function: _,
|
||||
}) => {
|
||||
|
@ -1153,7 +1260,7 @@ fn resolve_lambda_set(subs: &Subs, mut var: Variable) -> ResolvedLambdaSet {
|
|||
subs.uls_of_var
|
||||
);
|
||||
roc_types::pretty_print::push_union(subs, solved, &mut set);
|
||||
return ResolvedLambdaSet::Set(set);
|
||||
return ResolvedLambdaSet::Set(set, *recursion_var);
|
||||
}
|
||||
Content::RecursionVar { structure, .. } => {
|
||||
var = *structure;
|
||||
|
@ -2130,10 +2237,14 @@ fn layout_from_flat_type<'a>(
|
|||
}
|
||||
}
|
||||
Func(_, closure_var, _) => {
|
||||
let lambda_set =
|
||||
LambdaSet::from_var(env.arena, env.subs, closure_var, env.target_info)?;
|
||||
if env.is_seen(closure_var) {
|
||||
Ok(Layout::RecursivePointer)
|
||||
} else {
|
||||
let lambda_set =
|
||||
LambdaSet::from_var(env.arena, env.subs, closure_var, env.target_info)?;
|
||||
|
||||
Ok(Layout::LambdaSet(lambda_set))
|
||||
Ok(Layout::LambdaSet(lambda_set))
|
||||
}
|
||||
}
|
||||
Record(fields, ext_var) => {
|
||||
// extract any values from the ext_var
|
||||
|
|
|
@ -713,12 +713,13 @@ fn solve(
|
|||
new_env.insert_symbol_var_if_vacant(*symbol, loc_var.value);
|
||||
}
|
||||
|
||||
stack.push(Work::CheckForInfiniteTypes(local_def_vars));
|
||||
stack.push(Work::Constraint {
|
||||
env: arena.alloc(new_env),
|
||||
rank,
|
||||
constraint: ret_constraint,
|
||||
});
|
||||
// Check for infinite types first
|
||||
stack.push(Work::CheckForInfiniteTypes(local_def_vars));
|
||||
|
||||
continue;
|
||||
}
|
||||
|
@ -831,12 +832,13 @@ fn solve(
|
|||
|
||||
// Now solve the body, using the new vars_by_symbol which includes
|
||||
// the assignments' name-to-variable mappings.
|
||||
stack.push(Work::CheckForInfiniteTypes(local_def_vars));
|
||||
stack.push(Work::Constraint {
|
||||
env: arena.alloc(new_env),
|
||||
rank,
|
||||
constraint: ret_constraint,
|
||||
});
|
||||
// Check for infinite types first
|
||||
stack.push(Work::CheckForInfiniteTypes(local_def_vars));
|
||||
|
||||
state = state_for_ret_con;
|
||||
|
||||
|
@ -2874,28 +2876,34 @@ fn check_for_infinite_type(
|
|||
) {
|
||||
let var = loc_var.value;
|
||||
|
||||
while let Err((recursive, _chain)) = subs.occurs(var) {
|
||||
// try to make a union recursive, see if that helps
|
||||
match subs.get_content_without_compacting(recursive) {
|
||||
&Content::Structure(FlatType::TagUnion(tags, ext_var)) => {
|
||||
subs.mark_tag_union_recursive(recursive, tags, ext_var);
|
||||
}
|
||||
&Content::LambdaSet(subs::LambdaSet {
|
||||
solved,
|
||||
recursion_var: _,
|
||||
unspecialized,
|
||||
ambient_function: ambient_function_var,
|
||||
}) => {
|
||||
subs.mark_lambda_set_recursive(
|
||||
recursive,
|
||||
'next_occurs_check: while let Err((_, chain)) = subs.occurs(var) {
|
||||
// walk the chain till we find a tag union or lambda set, starting from the variable that
|
||||
// occurred recursively, which is always at the end of the chain.
|
||||
for &var in chain.iter().rev() {
|
||||
match *subs.get_content_without_compacting(var) {
|
||||
Content::Structure(FlatType::TagUnion(tags, ext_var)) => {
|
||||
subs.mark_tag_union_recursive(var, tags, ext_var);
|
||||
continue 'next_occurs_check;
|
||||
}
|
||||
Content::LambdaSet(subs::LambdaSet {
|
||||
solved,
|
||||
recursion_var: _,
|
||||
unspecialized,
|
||||
ambient_function_var,
|
||||
);
|
||||
ambient_function: ambient_function_var,
|
||||
}) => {
|
||||
subs.mark_lambda_set_recursive(
|
||||
var,
|
||||
solved,
|
||||
unspecialized,
|
||||
ambient_function_var,
|
||||
);
|
||||
continue 'next_occurs_check;
|
||||
}
|
||||
_ => { /* fall through */ }
|
||||
}
|
||||
|
||||
_other => circular_error(subs, problems, symbol, &loc_var),
|
||||
}
|
||||
|
||||
circular_error(subs, problems, symbol, &loc_var);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -7657,4 +7657,39 @@ mod solve_expr {
|
|||
"Num *",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn issue_3444() {
|
||||
infer_queries!(
|
||||
indoc!(
|
||||
r#"
|
||||
compose = \f, g ->
|
||||
closCompose = \x -> g (f x)
|
||||
closCompose
|
||||
|
||||
const = \x ->
|
||||
closConst = \_ -> x
|
||||
closConst
|
||||
|
||||
list = []
|
||||
|
||||
res : Str -> Str
|
||||
res = List.walk list (const "z") (\c1, c2 -> compose c1 c2)
|
||||
# ^^^^^ ^^^^^^^
|
||||
# ^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
#^^^{-1}
|
||||
|
||||
res "hello"
|
||||
#^^^{-1}
|
||||
"#
|
||||
),
|
||||
@r###"
|
||||
const : Str -[[const(2)]]-> (Str -[[closCompose(7) (Str -a-> Str) (Str -[[]]-> Str), closConst(10) Str] as a]-> Str)
|
||||
compose : (Str -a-> Str), (Str -[[]]-> Str) -[[compose(1)]]-> (Str -a-> Str)
|
||||
\c1, c2 -> compose c1 c2 : (Str -a-> Str), (Str -[[]]-> Str) -[[11(11)]]-> (Str -a-> Str)
|
||||
res : Str -[[closCompose(7) (Str -a-> Str) (Str -[[]]-> Str), closConst(10) Str] as a]-> Str
|
||||
res : Str -[[closCompose(7) (Str -a-> Str) (Str -[[]]-> Str), closConst(10) Str] as a]-> Str
|
||||
"###
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3705,3 +3705,136 @@ fn runtime_error_when_degenerate_pattern_reached() {
|
|||
true // allow errors
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
|
||||
fn recursive_lambda_set_issue_3444() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
combine = \f, g -> \x -> g (f x)
|
||||
const = \x -> (\_y -> x)
|
||||
|
||||
list = [const "a", const "b", const "c"]
|
||||
|
||||
res : Str -> Str
|
||||
res = List.walk list (const "z") (\c1, c2 -> combine c1 c2)
|
||||
res "hello"
|
||||
"#
|
||||
),
|
||||
RocStr::from("c"),
|
||||
RocStr
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
|
||||
fn recursive_lambda_set_toplevel_issue_3444() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
app "test" provides [main] to "./platform"
|
||||
|
||||
combine = \f, g -> \x -> g (f x)
|
||||
const = \x -> (\_y -> x)
|
||||
|
||||
list = [const "a", const "b", const "c"]
|
||||
|
||||
res : Str -> Str
|
||||
res = List.walk list (const "z") (\c1, c2 -> combine c1 c2)
|
||||
|
||||
main = res "hello"
|
||||
"#
|
||||
),
|
||||
RocStr::from("c"),
|
||||
RocStr
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
|
||||
fn recursive_lambda_set_issue_3444_inferred() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
combine = \f, g -> \x -> g (f x)
|
||||
const = \x -> (\_y -> x)
|
||||
|
||||
list = [const "a", const "b", const "c"]
|
||||
|
||||
res = List.walk list (const "z") (\c1, c2 -> combine c1 c2)
|
||||
res "hello"
|
||||
"#
|
||||
),
|
||||
RocStr::from("c"),
|
||||
RocStr
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
|
||||
fn compose_recursive_lambda_set_productive_toplevel() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
app "test" provides [main] to "./platform"
|
||||
|
||||
compose = \f, g -> \x -> g (f x)
|
||||
|
||||
identity = \x -> x
|
||||
exclaim = \s -> "\(s)!"
|
||||
whisper = \s -> "(\(s))"
|
||||
|
||||
main =
|
||||
res: Str -> Str
|
||||
res = List.walk [ exclaim, whisper ] identity compose
|
||||
res "hello"
|
||||
"#
|
||||
),
|
||||
RocStr::from("(hello!)"),
|
||||
RocStr
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
|
||||
fn compose_recursive_lambda_set_productive_nested() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
compose = \f, g -> \x -> g (f x)
|
||||
|
||||
identity = \x -> x
|
||||
exclaim = \s -> "\(s)!"
|
||||
whisper = \s -> "(\(s))"
|
||||
|
||||
res: Str -> Str
|
||||
res = List.walk [ exclaim, whisper ] identity compose
|
||||
res "hello"
|
||||
"#
|
||||
),
|
||||
RocStr::from("(hello!)"),
|
||||
RocStr
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
|
||||
fn compose_recursive_lambda_set_productive_inferred() {
|
||||
assert_evals_to!(
|
||||
indoc!(
|
||||
r#"
|
||||
compose = \f, g -> \x -> g (f x)
|
||||
|
||||
identity = \x -> x
|
||||
exclaim = \s -> "\(s)!"
|
||||
whisper = \s -> "(\(s))"
|
||||
|
||||
res = List.walk [ exclaim, whisper ] identity compose
|
||||
res "hello"
|
||||
"#
|
||||
),
|
||||
RocStr::from("(hello!)"),
|
||||
RocStr
|
||||
);
|
||||
}
|
||||
|
|
|
@ -2518,10 +2518,12 @@ fn unify_flat_type<M: MetaCollector>(
|
|||
outcome.union(arg_outcome);
|
||||
|
||||
if outcome.mismatches.is_empty() {
|
||||
let merged_closure_var = choose_merged_var(env.subs, *l_closure, *r_closure);
|
||||
|
||||
outcome.union(merge(
|
||||
env,
|
||||
ctx,
|
||||
Structure(Func(*r_args, *r_closure, *r_ret)),
|
||||
Structure(Func(*r_args, merged_closure_var, *r_ret)),
|
||||
));
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue