Merge pull request #3616 from rtfeldman/i3614

Compile branches in the presence of degenerate patterns
This commit is contained in:
Folkert de Vries 2022-07-25 19:45:36 +02:00 committed by GitHub
commit d212dffa1a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 210 additions and 66 deletions

View file

@ -1,6 +1,6 @@
use crate::{ use crate::{
def::Def, def::Def,
expr::{AccessorData, ClosureData, Expr, Field, OpaqueWrapFunctionData}, expr::{AccessorData, ClosureData, Expr, Field, OpaqueWrapFunctionData, WhenBranchPattern},
pattern::{DestructType, Pattern, RecordDestruct}, pattern::{DestructType, Pattern, RecordDestruct},
}; };
use roc_module::{ use roc_module::{
@ -295,7 +295,16 @@ fn deep_copy_expr_help<C: CopyEnv>(env: &mut C, copied: &mut Vec<Variable>, expr
}| crate::expr::WhenBranch { }| crate::expr::WhenBranch {
patterns: patterns patterns: patterns
.iter() .iter()
.map(|lp| lp.map(|p| deep_copy_pattern_help(env, copied, p))) .map(
|WhenBranchPattern {
pattern,
degenerate,
}| WhenBranchPattern {
pattern: pattern
.map(|p| deep_copy_pattern_help(env, copied, p)),
degenerate: *degenerate,
},
)
.collect(), .collect(),
value: value.map(|e| go_help!(e)), value: value.map(|e| go_help!(e)),
guard: guard.as_ref().map(|le| le.map(|e| go_help!(e))), guard: guard.as_ref().map(|le| le.map(|e| go_help!(e))),

View file

@ -1,6 +1,6 @@
use crate::annotation::IntroducedVariables; use crate::annotation::IntroducedVariables;
use crate::def::Def; use crate::def::Def;
use crate::expr::{AnnotatedMark, ClosureData, Declarations, Expr, Recursive}; use crate::expr::{AnnotatedMark, ClosureData, Declarations, Expr, Recursive, WhenBranchPattern};
use crate::pattern::Pattern; use crate::pattern::Pattern;
use crate::scope::Scope; use crate::scope::Scope;
use roc_collections::{SendMap, VecSet}; use roc_collections::{SendMap, VecSet};
@ -475,11 +475,15 @@ fn build_effect_after(
type_arguments, type_arguments,
lambda_set_variables, lambda_set_variables,
}; };
let pattern = WhenBranchPattern {
pattern: Loc::at_zero(pattern),
degenerate: false,
};
let branches = vec![crate::expr::WhenBranch { let branches = vec![crate::expr::WhenBranch {
guard: None, guard: None,
value: Loc::at_zero(force_inner_thunk_call), value: Loc::at_zero(force_inner_thunk_call),
patterns: vec![Loc::at_zero(pattern)], patterns: vec![pattern],
redundant: RedundantMark::new(var_store), redundant: RedundantMark::new(var_store),
}]; }];
@ -1256,9 +1260,13 @@ fn build_effect_loop_inner_body(
let step_tag_name = TagName("Step".into()); let step_tag_name = TagName("Step".into());
let step_pattern = applied_tag_pattern(step_tag_name, &[new_state_symbol], var_store); let step_pattern = applied_tag_pattern(step_tag_name, &[new_state_symbol], var_store);
let step_pattern = WhenBranchPattern {
pattern: Loc::at_zero(step_pattern),
degenerate: false,
};
crate::expr::WhenBranch { crate::expr::WhenBranch {
patterns: vec![Loc::at_zero(step_pattern)], patterns: vec![step_pattern],
value: Loc::at_zero(force_thunk2), value: Loc::at_zero(force_thunk2),
guard: None, guard: None,
redundant: RedundantMark::new(var_store), redundant: RedundantMark::new(var_store),
@ -1268,9 +1276,13 @@ fn build_effect_loop_inner_body(
let done_branch = { let done_branch = {
let done_tag_name = TagName("Done".into()); let done_tag_name = TagName("Done".into());
let done_pattern = applied_tag_pattern(done_tag_name, &[done_symbol], var_store); let done_pattern = applied_tag_pattern(done_tag_name, &[done_symbol], var_store);
let done_pattern = WhenBranchPattern {
pattern: Loc::at_zero(done_pattern),
degenerate: false,
};
crate::expr::WhenBranch { crate::expr::WhenBranch {
patterns: vec![Loc::at_zero(done_pattern)], patterns: vec![done_pattern],
value: Loc::at_zero(Expr::Var(done_symbol)), value: Loc::at_zero(Expr::Var(done_symbol)),
guard: None, guard: None,
redundant: RedundantMark::new(var_store), redundant: RedundantMark::new(var_store),

View file

@ -260,16 +260,19 @@ pub fn sketch_when_branches(
// NB: ordering the guard pattern first seems to be better at catching // NB: ordering the guard pattern first seems to be better at catching
// non-exhaustive constructors in the second argument; see the paper to see if // non-exhaustive constructors in the second argument; see the paper to see if
// there is a way to improve this in general. // there is a way to improve this in general.
vec![guard_pattern, sketch_pattern(target_var, &loc_pat.value)], vec![
guard_pattern,
sketch_pattern(target_var, &loc_pat.pattern.value),
],
)] )]
} else { } else {
// Simple case // Simple case
vec![sketch_pattern(target_var, &loc_pat.value)] vec![sketch_pattern(target_var, &loc_pat.pattern.value)]
}; };
let row = SketchedRow { let row = SketchedRow {
patterns, patterns,
region: loc_pat.region, region: loc_pat.pattern.region,
guard, guard,
redundant_mark: *redundant, redundant_mark: *redundant,
}; };

View file

@ -481,9 +481,18 @@ impl Recursive {
} }
} }
/// One pattern alternative of a `when` branch (a `WhenBranch` holds a `Vec` of these),
/// paired with a flag recording whether that alternative is degenerate.
#[derive(Clone, Debug)]
pub struct WhenBranchPattern {
/// The pattern itself, wrapped in `Loc` so its `region` travels with its `value`.
pub pattern: Loc<Pattern>,
/// Degenerate branch patterns are those that don't fully bind symbols that the branch body
/// needs. For example, in `A x | B y -> x`, the `B y` pattern is degenerate.
/// Degenerate patterns emit a runtime error if reached in a program.
pub degenerate: bool,
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct WhenBranch { pub struct WhenBranch {
pub patterns: Vec<Loc<Pattern>>, pub patterns: Vec<WhenBranchPattern>,
pub value: Loc<Expr>, pub value: Loc<Expr>,
pub guard: Option<Loc<Expr>>, pub guard: Option<Loc<Expr>>,
/// Whether this branch is redundant in the `when` it appears in /// Whether this branch is redundant in the `when` it appears in
@ -497,11 +506,13 @@ impl WhenBranch {
.patterns .patterns
.first() .first()
.expect("when branch has no pattern?") .expect("when branch has no pattern?")
.pattern
.region, .region,
&self &self
.patterns .patterns
.last() .last()
.expect("when branch has no pattern?") .expect("when branch has no pattern?")
.pattern
.region, .region,
) )
} }
@ -512,7 +523,7 @@ impl WhenBranch {
Region::across_all( Region::across_all(
self.patterns self.patterns
.iter() .iter()
.map(|p| &p.region) .map(|p| &p.pattern.region)
.chain([self.value.region].iter()), .chain([self.value.region].iter()),
) )
} }
@ -1350,14 +1361,20 @@ fn canonicalize_when_branch<'a>(
); );
multi_pattern_variables.add_pattern(&can_pattern); multi_pattern_variables.add_pattern(&can_pattern);
patterns.push(can_pattern); patterns.push(WhenBranchPattern {
pattern: can_pattern,
degenerate: false,
});
} }
let mut some_symbols_not_bound_in_all_patterns = false;
for (unbound_symbol, region) in multi_pattern_variables.get_unbound() { for (unbound_symbol, region) in multi_pattern_variables.get_unbound() {
env.problem(Problem::NotBoundInAllPatterns { env.problem(Problem::NotBoundInAllPatterns {
unbound_symbol, unbound_symbol,
region, region,
}) });
some_symbols_not_bound_in_all_patterns = true;
} }
let (value, mut branch_output) = canonicalize_expr( let (value, mut branch_output) = canonicalize_expr(
@ -1384,12 +1401,33 @@ fn canonicalize_when_branch<'a>(
// Now that we've collected all the references for this branch, check to see if // Now that we've collected all the references for this branch, check to see if
// any of the new idents it defined were unused. If any were, report it. // any of the new idents it defined were unused. If any were, report it.
for (symbol, region) in BindingsFromPattern::new_many(patterns.iter()) { let mut pattern_bound_symbols_body_needs = VecSet::default();
if !output.references.has_value_lookup(symbol) { for (symbol, region) in BindingsFromPattern::new_many(patterns.iter().map(|pat| &pat.pattern)) {
if output.references.has_value_lookup(symbol) {
pattern_bound_symbols_body_needs.insert(symbol);
} else {
env.problem(Problem::UnusedDef(symbol, region)); env.problem(Problem::UnusedDef(symbol, region));
} }
} }
if some_symbols_not_bound_in_all_patterns && !pattern_bound_symbols_body_needs.is_empty() {
// There might be patterns that don't bind all the symbols needed by the body; mark those
// patterns degenerate.
for pattern in patterns.iter_mut() {
let bound_by_pattern: VecSet<_> = BindingsFromPattern::new(&pattern.pattern)
.map(|(sym, _)| sym)
.collect();
let binds_all_needed = pattern_bound_symbols_body_needs
.iter()
.all(|sym| bound_by_pattern.contains(sym));
if !binds_all_needed {
pattern.degenerate = true;
}
}
}
( (
WhenBranch { WhenBranch {
patterns, patterns,

View file

@ -910,7 +910,10 @@ fn fix_values_captured_in_closure_expr(
// patterns can contain default expressions, so we must go over them too! // patterns can contain default expressions, so we must go over them too!
for loc_pat in branch.patterns.iter_mut() { for loc_pat in branch.patterns.iter_mut() {
fix_values_captured_in_closure_pattern(&mut loc_pat.value, no_capture_symbols); fix_values_captured_in_closure_pattern(
&mut loc_pat.pattern.value,
no_capture_symbols,
);
} }
if let Some(guard) = &mut branch.guard { if let Some(guard) = &mut branch.guard {

View file

@ -323,9 +323,13 @@ pub fn walk_when_branch<V: Visitor>(
redundant: _, redundant: _,
} = branch; } = branch;
patterns patterns.iter().for_each(|pat| {
.iter() visitor.visit_pattern(
.for_each(|pat| visitor.visit_pattern(&pat.value, pat.region, pat.value.opt_var())); &pat.pattern.value,
pat.pattern.region,
pat.pattern.value.opt_var(),
)
});
visitor.visit_expr(&value.value, value.region, expr_var); visitor.visit_expr(&value.value, value.region, expr_var);
if let Some(guard) = guard { if let Some(guard) = guard {
visitor.visit_expr(&guard.value, guard.region, Variable::BOOL); visitor.visit_expr(&guard.value, guard.region, Variable::BOOL);

View file

@ -1837,14 +1837,15 @@ fn constrain_when_branch_help(
}; };
for (i, loc_pattern) in when_branch.patterns.iter().enumerate() { for (i, loc_pattern) in when_branch.patterns.iter().enumerate() {
let pattern_expected = pattern_expected(HumanIndex::zero_based(i), loc_pattern.region); let pattern_expected =
pattern_expected(HumanIndex::zero_based(i), loc_pattern.pattern.region);
let mut partial_state = PatternState::default(); let mut partial_state = PatternState::default();
constrain_pattern( constrain_pattern(
constraints, constraints,
env, env,
&loc_pattern.value, &loc_pattern.pattern.value,
loc_pattern.region, loc_pattern.pattern.region,
pattern_expected, pattern_expected,
&mut partial_state, &mut partial_state,
); );

View file

@ -3,7 +3,9 @@
use std::iter::once; use std::iter::once;
use roc_can::abilities::SpecializationLambdaSets; use roc_can::abilities::SpecializationLambdaSets;
use roc_can::expr::{AnnotatedMark, ClosureData, Expr, Field, Recursive, WhenBranch}; use roc_can::expr::{
AnnotatedMark, ClosureData, Expr, Field, Recursive, WhenBranch, WhenBranchPattern,
};
use roc_can::module::ExposedByModule; use roc_can::module::ExposedByModule;
use roc_can::pattern::Pattern; use roc_can::pattern::Pattern;
use roc_collections::SendMap; use roc_collections::SendMap;
@ -672,6 +674,10 @@ fn to_encoder_tag_union(
.map(|(var, sym)| (*var, Loc::at_zero(Pattern::Identifier(*sym)))) .map(|(var, sym)| (*var, Loc::at_zero(Pattern::Identifier(*sym))))
.collect(), .collect(),
}; };
let branch_pattern = WhenBranchPattern {
pattern: Loc::at_zero(pattern),
degenerate: false,
};
// whole type of the elements in [ Encode.toEncoder v1, Encode.toEncoder v2 ] // whole type of the elements in [ Encode.toEncoder v1, Encode.toEncoder v2 ]
let whole_payload_encoders_var = env.subs.fresh_unnamed_flex_var(); let whole_payload_encoders_var = env.subs.fresh_unnamed_flex_var();
@ -792,7 +798,7 @@ fn to_encoder_tag_union(
env.unify(this_encoder_var, whole_tag_encoders_var); env.unify(this_encoder_var, whole_tag_encoders_var);
WhenBranch { WhenBranch {
patterns: vec![Loc::at_zero(pattern)], patterns: vec![branch_pattern],
value: Loc::at_zero(encode_tag_call), value: Loc::at_zero(encode_tag_call),
guard: None, guard: None,
redundant: RedundantMark::known_non_redundant(), redundant: RedundantMark::known_non_redundant(),

View file

@ -20,7 +20,7 @@ use roc_debug_flags::{
}; };
use roc_derive::SharedDerivedModule; use roc_derive::SharedDerivedModule;
use roc_error_macros::{internal_error, todo_abilities}; use roc_error_macros::{internal_error, todo_abilities};
use roc_exhaustive::{Ctor, CtorName, Guard, RenderAs, TagId}; use roc_exhaustive::{Ctor, CtorName, RenderAs, TagId};
use roc_late_solve::{resolve_ability_specialization, AbilitiesView, Resolved, UnificationFailed}; use roc_late_solve::{resolve_ability_specialization, AbilitiesView, Resolved, UnificationFailed};
use roc_module::ident::{ForeignSymbol, Lowercase, TagName}; use roc_module::ident::{ForeignSymbol, Lowercase, TagName};
use roc_module::low_level::LowLevel; use roc_module::low_level::LowLevel;
@ -2533,7 +2533,7 @@ fn pattern_to_when<'a>(
body: Loc<roc_can::expr::Expr>, body: Loc<roc_can::expr::Expr>,
) -> (Symbol, Loc<roc_can::expr::Expr>) { ) -> (Symbol, Loc<roc_can::expr::Expr>) {
use roc_can::expr::Expr::*; use roc_can::expr::Expr::*;
use roc_can::expr::WhenBranch; use roc_can::expr::{WhenBranch, WhenBranchPattern};
use roc_can::pattern::Pattern::*; use roc_can::pattern::Pattern::*;
match &pattern.value { match &pattern.value {
@ -2580,7 +2580,10 @@ fn pattern_to_when<'a>(
region: Region::zero(), region: Region::zero(),
loc_cond: Box::new(Loc::at_zero(Var(symbol))), loc_cond: Box::new(Loc::at_zero(Var(symbol))),
branches: vec![WhenBranch { branches: vec![WhenBranch {
patterns: vec![pattern], patterns: vec![WhenBranchPattern {
pattern,
degenerate: false,
}],
value: body, value: body,
guard: None, guard: None,
// If this type-checked, it's non-redundant // If this type-checked, it's non-redundant
@ -5184,7 +5187,7 @@ pub fn with_hole<'a>(
} }
} }
TypedHole(_) => Stmt::RuntimeError("Hit a blank"), TypedHole(_) => Stmt::RuntimeError("Hit a blank"),
RuntimeError(e) => Stmt::RuntimeError(env.arena.alloc(format!("{:?}", e))), RuntimeError(e) => Stmt::RuntimeError(env.arena.alloc(e.runtime_message())),
} }
} }
@ -6033,56 +6036,50 @@ fn to_opt_branches<'a>(
)> { )> {
debug_assert!(!branches.is_empty()); debug_assert!(!branches.is_empty());
let mut loc_branches = std::vec::Vec::new();
let mut opt_branches = std::vec::Vec::new(); let mut opt_branches = std::vec::Vec::new();
for when_branch in branches { for when_branch in branches {
let exhaustive_guard = if when_branch.guard.is_some() {
Guard::HasGuard
} else {
Guard::NoGuard
};
if when_branch.redundant.is_redundant(env.subs) { if when_branch.redundant.is_redundant(env.subs) {
// Don't codegen this branch since it's redundant. // Don't codegen this branch since it's redundant.
continue; continue;
} }
for loc_pattern in when_branch.patterns { for loc_pattern in when_branch.patterns {
match from_can_pattern(env, procs, layout_cache, &loc_pattern.value) { match from_can_pattern(env, procs, layout_cache, &loc_pattern.pattern.value) {
Ok((mono_pattern, assignments)) => { Ok((mono_pattern, assignments)) => {
loc_branches.push(( let loc_expr = if !loc_pattern.degenerate {
Loc::at(loc_pattern.region, mono_pattern.clone()), let mut loc_expr = when_branch.value.clone();
exhaustive_guard,
));
let mut loc_expr = when_branch.value.clone(); let region = loc_pattern.pattern.region;
let region = loc_pattern.region; for (symbol, variable, expr) in assignments.into_iter().rev() {
for (symbol, variable, expr) in assignments.into_iter().rev() { let def = roc_can::def::Def {
let def = roc_can::def::Def { annotation: None,
annotation: None, expr_var: variable,
expr_var: variable, loc_expr: Loc::at(region, expr),
loc_expr: Loc::at(region, expr), loc_pattern: Loc::at(
loc_pattern: Loc::at( region,
region, roc_can::pattern::Pattern::Identifier(symbol),
roc_can::pattern::Pattern::Identifier(symbol), ),
), pattern_vars: std::iter::once((symbol, variable)).collect(),
pattern_vars: std::iter::once((symbol, variable)).collect(), };
}; let new_expr =
let new_expr = roc_can::expr::Expr::LetNonRec(Box::new(def), Box::new(loc_expr));
roc_can::expr::Expr::LetNonRec(Box::new(def), Box::new(loc_expr)); loc_expr = Loc::at(region, new_expr);
loc_expr = Loc::at(region, new_expr); }
}
loc_expr
} else {
// This pattern is degenerate; when it's reached we must emit a runtime
// error.
Loc::at_zero(roc_can::expr::Expr::RuntimeError(
RuntimeError::DegenerateBranch(loc_pattern.pattern.region),
))
};
// TODO remove clone? // TODO remove clone?
opt_branches.push((mono_pattern, when_branch.guard.clone(), loc_expr.value)); opt_branches.push((mono_pattern, when_branch.guard.clone(), loc_expr.value));
} }
Err(runtime_error) => { Err(runtime_error) => {
loc_branches.push((
Loc::at(loc_pattern.region, Pattern::Underscore),
exhaustive_guard,
));
// TODO remove clone? // TODO remove clone?
opt_branches.push(( opt_branches.push((
Pattern::Underscore, Pattern::Underscore,

View file

@ -329,6 +329,24 @@ pub enum RuntimeError {
EmptySingleQuote(Region), EmptySingleQuote(Region),
/// where 'aa' /// where 'aa'
MultipleCharsInSingleQuote(Region), MultipleCharsInSingleQuote(Region),
DegenerateBranch(Region),
}
impl RuntimeError {
pub fn runtime_message(self) -> String {
use RuntimeError::*;
match self {
DegenerateBranch(region) => {
format!(
"Hit a branch pattern that does not bind all symbols its body needs, at {:?}",
region
)
}
err => format!("{:?}", err),
}
}
} }
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, PartialEq)]

View file

@ -221,7 +221,7 @@ fn branch<'a>(c: &Ctx, f: &'a Arena<'a>, b: &'a WhenBranch) -> DocBuilder<'a, Ar
f.intersperse( f.intersperse(
patterns patterns
.iter() .iter()
.map(|lp| pattern(c, PPrec::Free, f, &lp.value)), .map(|lp| pattern(c, PPrec::Free, f, &lp.pattern.value)),
f.text(" | "), f.text(" | "),
) )
.append(match guard { .append(match guard {

View file

@ -3651,7 +3651,25 @@ fn shared_pattern_variable_in_when_branches() {
} }
#[test] #[test]
#[ignore = "TODO currently fails in alias analysis because `B y` does not introduce `x`"] #[cfg(any(feature = "gen-llvm"))]
fn symbol_not_bound_in_all_patterns_runs_when_no_bound_symbol_used() {
assert_evals_to!(
indoc!(
r#"
f = \t -> when t is
A x | B y -> 31u8
{a: f (A 15u8), b: f (B 15u8)}
"#
),
31u8,
u8,
|x| x,
true // allow errors
);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))] #[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn symbol_not_bound_in_all_patterns_runs_when_bound_pattern_reached() { fn symbol_not_bound_in_all_patterns_runs_when_bound_pattern_reached() {
assert_evals_to!( assert_evals_to!(
@ -3662,6 +3680,28 @@ fn symbol_not_bound_in_all_patterns_runs_when_bound_pattern_reached() {
"# "#
), ),
15u8, 15u8,
u8 u8,
|x| x,
true // allow errors
);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
#[should_panic(
expected = r#"Roc failed with message: "Hit a branch pattern that does not bind all symbols its body needs"#
)]
fn runtime_error_when_degenerate_pattern_reached() {
assert_evals_to!(
indoc!(
r#"
when B 15u8 is
A x | B y -> x + 5u8
"#
),
15u8,
u8,
|x| x,
true // allow errors
); );
} }

View file

@ -272,7 +272,12 @@ fn create_llvm_module<'a>(
// Uncomment this to see the module's optimized LLVM instruction output: // Uncomment this to see the module's optimized LLVM instruction output:
// env.module.print_to_stderr(); // env.module.print_to_stderr();
(main_fn_name, delayed_errors.join("\n"), env.module) let delayed_errors = if config.ignore_problems {
String::new()
} else {
delayed_errors.join("\n")
};
(main_fn_name, delayed_errors, env.module)
} }
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]

View file

@ -1899,6 +1899,14 @@ fn pretty_runtime_error<'b>(
title = OPAQUE_OVER_APPLIED; title = OPAQUE_OVER_APPLIED;
} }
RuntimeError::DegenerateBranch(region) => {
doc = alloc.stack([
alloc.reflow("This branch pattern does not bind all symbols its body needs:"),
alloc.region(lines.convert_region(region)),
]);
title = "DEGENERATE BRANCH";
}
} }
(doc, title) (doc, title)