Do not revisit variables in an occurs check

Turns out this mark cache check is unreasonably effective, even if it
is naive.
This commit is contained in:
Ayaz Hafiz 2023-04-06 14:42:31 -05:00
parent d7528c528b
commit a816f8bc83
No known key found for this signature in database
GPG key ID: 0E2A37416A25EF58

View file

@ -25,7 +25,8 @@ roc_error_macros::assert_sizeof_all!(RecordFields, 2 * 8);
pub struct Mark(i32); pub struct Mark(i32);
impl Mark { impl Mark {
pub const NONE: Mark = Mark(2); pub const NONE: Mark = Mark(3);
pub const VISITED_IN_OCCURS_CHECK: Mark = Mark(2);
pub const OCCURS: Mark = Mark(1); pub const OCCURS: Mark = Mark(1);
pub const GET_VAR_NAMES: Mark = Mark(0); pub const GET_VAR_NAMES: Mark = Mark(0);
@ -1996,9 +1997,15 @@ impl Subs {
/// ///
/// This ignores [Content::RecursionVar]s that occur recursively, because those are /// This ignores [Content::RecursionVar]s that occur recursively, because those are
/// already priced in and expected to occur. /// already priced in and expected to occur.
pub fn occurs(&self, var: Variable) -> Result<(), (Variable, Vec<Variable>)> { ///
/// Although `subs` is taken as mutable reference, this function will return it in the same
/// state it was given.
pub fn occurs(&mut self, var: Variable) -> Result<(), (Variable, Vec<Variable>)> {
let mut scratchpad = take_occurs_scratchpad(); let mut scratchpad = take_occurs_scratchpad();
let result = occurs(self, &mut scratchpad, var); let result = occurs(self, &mut scratchpad, var);
for v in &scratchpad.all_visited {
self.set_mark_unchecked(*v, Mark::NONE);
}
put_occurs_scratchpad(scratchpad); put_occurs_scratchpad(scratchpad);
result result
} }
@ -3434,15 +3441,34 @@ impl TupleElems {
} }
} }
std::thread_local! { struct OccursScratchpad {
static SCRATCHPAD_FOR_OCCURS: RefCell<Option<Vec<Variable>>> = RefCell::new(Some(Vec::with_capacity(1024))); seen: Vec<Variable>,
all_visited: Vec<Variable>,
} }
fn take_occurs_scratchpad() -> Vec<Variable> { impl OccursScratchpad {
fn new_static() -> Self {
Self {
seen: Vec::with_capacity(1024),
all_visited: Vec::with_capacity(1024),
}
}
fn clear(&mut self) {
self.seen.clear();
self.all_visited.clear();
}
}
std::thread_local! {
static SCRATCHPAD_FOR_OCCURS: RefCell<Option<OccursScratchpad>> = RefCell::new(Some(OccursScratchpad::new_static()));
}
fn take_occurs_scratchpad() -> OccursScratchpad {
SCRATCHPAD_FOR_OCCURS.with(|f| f.take().unwrap()) SCRATCHPAD_FOR_OCCURS.with(|f| f.take().unwrap())
} }
fn put_occurs_scratchpad(mut scratchpad: Vec<Variable>) { fn put_occurs_scratchpad(mut scratchpad: OccursScratchpad) {
SCRATCHPAD_FOR_OCCURS.with(|f| { SCRATCHPAD_FOR_OCCURS.with(|f| {
scratchpad.clear(); scratchpad.clear();
f.replace(Some(scratchpad)); f.replace(Some(scratchpad));
@ -3450,19 +3476,36 @@ fn put_occurs_scratchpad(mut scratchpad: Vec<Variable>) {
} }
fn occurs( fn occurs(
subs: &Subs, subs: &mut Subs,
seen: &mut Vec<Variable>, ctx: &mut OccursScratchpad,
input_var: Variable, input_var: Variable,
) -> Result<(), (Variable, Vec<Variable>)> { ) -> Result<(), (Variable, Vec<Variable>)> {
// NB(subs-invariant): it is pivotal that subs is not modified in any material way.
// As variables are visited, they are marked as observed so they are not revisited,
// but no other modification should take place.
use self::Content::*; use self::Content::*;
use self::FlatType::*; use self::FlatType::*;
let root_var = subs.get_root_key_without_compacting(input_var); let root_var = subs.get_root_key_without_compacting(input_var);
if seen.contains(&root_var) { // SAFETY: due to XREF(subs-invariant), only the mark in a variable is modified, and all
// variable (and other content) identities are guaranteed to be preserved during an occurs
// check. As a result, we can freely take references of variables and UnionTags.
macro_rules! safe {
($t:ty, $expr:expr) => {
unsafe { std::mem::transmute::<_, &'static $t>($expr) }
};
}
if ctx.seen.contains(&root_var) {
Err((root_var, Vec::with_capacity(0))) Err((root_var, Vec::with_capacity(0)))
} else if subs.get_mark_unchecked(root_var) == Mark::VISITED_IN_OCCURS_CHECK {
Ok(())
} else { } else {
seen.push(root_var); ctx.all_visited.push(root_var);
subs.set_mark_unchecked(root_var, Mark::VISITED_IN_OCCURS_CHECK);
ctx.seen.push(root_var);
let result = (|| match subs.get_content_without_compacting(root_var) { let result = (|| match subs.get_content_without_compacting(root_var) {
FlexVar(_) FlexVar(_)
| RigidVar(_) | RigidVar(_)
@ -3472,47 +3515,57 @@ fn occurs(
| Error => Ok(()), | Error => Ok(()),
Structure(flat_type) => match flat_type { Structure(flat_type) => match flat_type {
Apply(_, args) => { Apply(_, args) => short_circuit(
short_circuit(subs, root_var, seen, subs.get_subs_slice(*args).iter()) subs,
} root_var,
ctx,
safe!([Variable], subs.get_subs_slice(*args)).iter(),
),
Func(arg_vars, closure_var, ret_var) => { Func(arg_vars, closure_var, ret_var) => {
let it = once(ret_var) let it = once(safe!(Variable, ret_var))
.chain(once(closure_var)) .chain(once(safe!(Variable, closure_var)))
.chain(subs.get_subs_slice(*arg_vars).iter()); .chain(safe!([Variable], subs.get_subs_slice(*arg_vars)).iter());
short_circuit(subs, root_var, seen, it) short_circuit(subs, root_var, ctx, it)
} }
Record(vars_by_field, ext) => { Record(vars_by_field, ext) => {
let slice = SubsSlice::new(vars_by_field.variables_start, vars_by_field.length); let slice =
let it = once(ext).chain(subs.get_subs_slice(slice).iter()); VariableSubsSlice::new(vars_by_field.variables_start, vars_by_field.length);
short_circuit(subs, root_var, seen, it) let it = once(safe!(Variable, ext))
.chain(safe!([Variable], subs.get_subs_slice(slice)).iter());
short_circuit(subs, root_var, ctx, it)
} }
Tuple(vars_by_elem, ext) => { Tuple(vars_by_elem, ext) => {
let slice = SubsSlice::new(vars_by_elem.variables_start, vars_by_elem.length); let slice =
let it = once(ext).chain(subs.get_subs_slice(slice).iter()); VariableSubsSlice::new(vars_by_elem.variables_start, vars_by_elem.length);
short_circuit(subs, root_var, seen, it) let it = once(safe!(Variable, ext))
.chain(safe!([Variable], subs.get_subs_slice(slice)).iter());
short_circuit(subs, root_var, ctx, it)
} }
TagUnion(tags, ext) => { TagUnion(tags, ext) => {
occurs_union(subs, root_var, seen, tags)?; let ext_var = ext.var();
occurs_union(subs, root_var, ctx, safe!(UnionLabels<TagName>, tags))?;
short_circuit_help(subs, root_var, seen, ext.var()) short_circuit_help(subs, root_var, ctx, ext_var)
} }
FunctionOrTagUnion(_, _, ext) => { FunctionOrTagUnion(_, _, ext) => {
short_circuit(subs, root_var, seen, once(&ext.var())) short_circuit(subs, root_var, ctx, once(&ext.var()))
} }
RecursiveTagUnion(_, tags, ext) => { RecursiveTagUnion(_, tags, ext) => {
occurs_union(subs, root_var, seen, tags)?; let ext_var = ext.var();
occurs_union(subs, root_var, ctx, safe!(UnionLabels<TagName>, tags))?;
short_circuit_help(subs, root_var, seen, ext.var()) short_circuit_help(subs, root_var, ctx, ext_var)
} }
EmptyRecord | EmptyTuple | EmptyTagUnion => Ok(()), EmptyRecord | EmptyTuple | EmptyTagUnion => Ok(()),
}, },
Alias(_, args, real_var, _) => { Alias(_, args, real_var, _) => {
let real_var = *real_var;
for var_index in args.into_iter() { for var_index in args.into_iter() {
let var = subs[var_index]; let var = subs[var_index];
if short_circuit_help(subs, root_var, seen, var).is_err() { if short_circuit_help(subs, root_var, ctx, var).is_err() {
// Pay the cost and figure out what the actual recursion point is // Pay the cost and figure out what the actual recursion point is
return short_circuit_help(subs, root_var, seen, *real_var); return short_circuit_help(subs, root_var, ctx, real_var);
} }
} }
@ -3527,27 +3580,27 @@ fn occurs(
// unspecialized lambda vars excluded because they are not explicitly part of the // unspecialized lambda vars excluded because they are not explicitly part of the
// type (they only matter after being resolved). // type (they only matter after being resolved).
occurs_union(subs, root_var, seen, solved) occurs_union(subs, root_var, ctx, safe!(UnionLabels<Symbol>, solved))
} }
RangedNumber(_range_vars) => Ok(()), RangedNumber(_range_vars) => Ok(()),
})(); })();
seen.pop(); ctx.seen.pop();
result result
} }
} }
#[inline(always)] #[inline(always)]
fn occurs_union<L: Label>( fn occurs_union<L: Label>(
subs: &Subs, subs: &mut Subs,
root_var: Variable, root_var: Variable,
seen: &mut Vec<Variable>, ctx: &mut OccursScratchpad,
tags: &UnionLabels<L>, tags: &UnionLabels<L>,
) -> Result<(), (Variable, Vec<Variable>)> { ) -> Result<(), (Variable, Vec<Variable>)> {
for slice_index in tags.variables() { for slice_index in tags.variables() {
let slice = subs[slice_index]; let slice = subs[slice_index];
for var_index in slice { for var_index in slice {
let var = subs[var_index]; let var = subs[var_index];
short_circuit_help(subs, root_var, seen, var)?; short_circuit_help(subs, root_var, ctx, var)?;
} }
} }
Ok(()) Ok(())
@ -3555,16 +3608,16 @@ fn occurs_union<L: Label>(
#[inline(always)] #[inline(always)]
fn short_circuit<'a, T>( fn short_circuit<'a, T>(
subs: &Subs, subs: &mut Subs,
root_key: Variable, root_key: Variable,
seen: &mut Vec<Variable>, ctx: &mut OccursScratchpad,
iter: T, iter: T,
) -> Result<(), (Variable, Vec<Variable>)> ) -> Result<(), (Variable, Vec<Variable>)>
where where
T: Iterator<Item = &'a Variable>, T: Iterator<Item = &'a Variable>,
{ {
for var in iter { for var in iter {
short_circuit_help(subs, root_key, seen, *var)?; short_circuit_help(subs, root_key, ctx, *var)?;
} }
Ok(()) Ok(())
@ -3572,12 +3625,12 @@ where
#[inline(always)] #[inline(always)]
fn short_circuit_help( fn short_circuit_help(
subs: &Subs, subs: &mut Subs,
root_key: Variable, root_key: Variable,
seen: &mut Vec<Variable>, ctx: &mut OccursScratchpad,
var: Variable, var: Variable,
) -> Result<(), (Variable, Vec<Variable>)> { ) -> Result<(), (Variable, Vec<Variable>)> {
if let Err((v, mut vec)) = occurs(subs, seen, var) { if let Err((v, mut vec)) = occurs(subs, ctx, var) {
vec.push(root_key); vec.push(root_key);
return Err((v, vec)); return Err((v, vec));
} }