Use an iterator to walk over pattern bindings

This commit is contained in:
Ayaz Hafiz 2023-03-25 17:03:34 -05:00
parent f75248d206
commit 93dc3714de
No known key found for this signature in database
GPG key ID: 0E2A37416A25EF58
2 changed files with 144 additions and 57 deletions

View file

@ -10,6 +10,7 @@ use crate::layout::{
};
use roc_builtins::bitcode::{FloatWidth, IntWidth};
use roc_collections::all::{MutMap, MutSet};
use roc_collections::BumpMap;
use roc_error_macros::internal_error;
use roc_exhaustive::{Ctor, CtorName, ListArity, RenderAs, TagId, Union};
use roc_module::ident::TagName;
@ -1474,12 +1475,12 @@ pub(crate) fn optimize_when<'a>(
// updating.
let pattern_bindings = pattern.collect_symbols(cond_layout);
let mut parameters_buf =
bumpalo::collections::Vec::with_capacity_in(pattern_bindings.len(), env.arena);
let mut parameters_buf = bumpalo::collections::Vec::with_capacity_in(1, env.arena);
let mut pattern_symbols_buf =
bumpalo::collections::Vec::with_capacity_in(pattern_bindings.len(), env.arena);
bumpalo::collections::Vec::with_capacity_in(1, env.arena);
let mut substitutions = BumpMap::default();
for &(pattern_symbol, layout) in pattern_bindings.iter() {
for (pattern_symbol, layout) in pattern_bindings {
let param_symbol = env.unique_symbol();
parameters_buf.push(Param {
symbol: param_symbol,
@ -1487,16 +1488,12 @@ pub(crate) fn optimize_when<'a>(
ownership: Ownership::Owned,
});
pattern_symbols_buf.push(pattern_symbol);
substitutions.insert(pattern_symbol, param_symbol);
}
join_params = parameters_buf.into_bump_slice();
jump_pattern_param_symbols = pattern_symbols_buf.into_bump_slice();
let substitutions = pattern_bindings
.iter()
.zip(join_params.iter())
.map(|((pat, _), param)| (*pat, param.symbol))
.collect();
substitute_in_exprs_many(env.arena, &mut branch, substitutions);
}
}

View file

@ -113,60 +113,150 @@ impl<'a> Pattern<'a> {
false
}
// TODO: vast majority of the time, the patterns will be singleton or empty.
// We should introduce a smallvec optimized for the singleton case.
pub fn collect_symbols(&self, layout: InLayout<'a>) -> std::vec::Vec<(Symbol, InLayout<'a>)> {
let mut stack = vec![(self, layout)];
let mut collected = std::vec::Vec::with_capacity(1);
pub fn collect_symbols(
&self,
layout: InLayout<'a>,
) -> impl Iterator<Item = (Symbol, InLayout<'a>)> + '_ {
PatternBindingIter::One(self, layout)
}
}
while let Some((pattern, layout)) = stack.pop() {
match pattern {
Pattern::Identifier(symbol) => {
collected.push((*symbol, layout));
}
Pattern::Underscore => {}
Pattern::As(subpattern, symbol) => {
collected.push((*symbol, layout));
stack.push((subpattern, layout));
}
Pattern::IntLiteral(_, _)
| Pattern::FloatLiteral(_, _)
| Pattern::DecimalLiteral(_)
| Pattern::BitLiteral { .. }
| Pattern::EnumLiteral { .. }
| Pattern::StrLiteral(_) => {}
Pattern::RecordDestructure(destructs, _) => {
for destruct in destructs {
match &destruct.typ {
DestructType::Required(symbol) => {
collected.push((*symbol, destruct.layout));
}
DestructType::Guard(pattern) => {
stack.push((pattern, destruct.layout));
}
}
enum PatternBindingIter<'r, 'a> {
Done,
One(&'r Pattern<'a>, InLayout<'a>),
Stack(std::vec::Vec<(PatternBindingWork<'r, 'a>, InLayout<'a>)>),
}
enum PatternBindingWork<'r, 'a> {
Pat(&'r Pattern<'a>),
RecordDestruct(&'r DestructType<'a>),
}
impl<'r, 'a> Iterator for PatternBindingIter<'r, 'a> {
type Item = (Symbol, InLayout<'a>);
fn next(&mut self) -> Option<Self::Item> {
use Pattern::*;
use PatternBindingIter::*;
use PatternBindingWork::*;
match self {
Done => None,
One(pattern, layout) => {
let layout = *layout;
match pattern {
Identifier(symbol) => {
*self = Done;
(*symbol, layout).into()
}
}
Pattern::TupleDestructure(destructs, _) => {
for destruct in destructs {
stack.push((&destruct.pat, destruct.layout));
Underscore => None,
As(pat, symbol) => {
*self = One(&**pat, layout);
(*symbol, layout).into()
}
}
Pattern::NewtypeDestructure { arguments, .. } => {
stack.extend(arguments.iter().map(|(t, l)| (t, *l)))
}
Pattern::Voided { .. } => {}
Pattern::AppliedTag { arguments, .. } => {
stack.extend(arguments.iter().map(|(t, l)| (t, *l)))
}
Pattern::OpaqueUnwrap { argument, .. } => stack.push((&argument.0, argument.1)),
Pattern::List { elements, .. } => {
stack.extend(elements.iter().map(|t| (t, layout)))
RecordDestructure(destructs, _) => {
let stack = destructs
.iter()
.map(|destruct| (RecordDestruct(&destruct.typ), destruct.layout))
.rev()
.collect();
*self = Stack(stack);
self.next()
}
TupleDestructure(destructs, _) => {
let stack = destructs
.iter()
.map(|destruct| (Pat(&destruct.pat), destruct.layout))
.rev()
.collect();
*self = Stack(stack);
self.next()
}
NewtypeDestructure { arguments, .. } | AppliedTag { arguments, .. } => {
let stack = arguments.iter().map(|(p, l)| (Pat(p), *l)).rev().collect();
*self = Stack(stack);
self.next()
}
OpaqueUnwrap { argument, .. } => {
*self = One(&argument.0, layout);
self.next()
}
List {
element_layout,
elements,
..
} => {
let stack = elements
.iter()
.map(|p| (Pat(p), *element_layout))
.rev()
.collect();
*self = Stack(stack);
self.next()
}
IntLiteral(_, _)
| FloatLiteral(_, _)
| DecimalLiteral(_)
| BitLiteral { .. }
| EnumLiteral { .. }
| StrLiteral(_)
| Voided { .. } => None,
}
}
}
Stack(stack) => {
while let Some((pat, layout)) = stack.pop() {
match pat {
Pat(pattern) => match pattern {
Identifier(symbol) => return (*symbol, layout).into(),
As(pat, symbol) => {
stack.push((Pat(pat), layout));
return (*symbol, layout).into();
}
RecordDestructure(destructs, _) => stack.extend(
destructs
.iter()
.map(|destruct| {
(RecordDestruct(&destruct.typ), destruct.layout)
})
.rev(),
),
TupleDestructure(destructs, _) => stack.extend(
destructs
.iter()
.map(|destruct| (Pat(&destruct.pat), destruct.layout))
.rev(),
),
NewtypeDestructure { arguments, .. } | AppliedTag { arguments, .. } => {
stack.extend(arguments.iter().map(|(p, l)| (Pat(p), *l)).rev())
}
OpaqueUnwrap { argument, .. } => {
stack.push((Pat(&argument.0), layout));
}
List {
element_layout,
elements,
..
} => {
stack.extend(
elements.iter().map(|p| (Pat(p), *element_layout)).rev(),
);
}
IntLiteral(_, _)
| FloatLiteral(_, _)
| DecimalLiteral(_)
| BitLiteral { .. }
| EnumLiteral { .. }
| Underscore
| StrLiteral(_)
| Voided { .. } => {}
},
PatternBindingWork::RecordDestruct(_) => todo!(),
}
}
collected
*self = Done;
None
}
}
}
}