From a816f8bc83ceb59f367a9e77d611c1df2d690d77 Mon Sep 17 00:00:00 2001 From: Ayaz Hafiz Date: Thu, 6 Apr 2023 14:42:31 -0500 Subject: [PATCH] Do not revisit variables in an occurs check Turns out this mark cache check is unreasonably effective, even if it is naive. --- crates/compiler/types/src/subs.rs | 135 +++++++++++++++++++++--------- 1 file changed, 94 insertions(+), 41 deletions(-) diff --git a/crates/compiler/types/src/subs.rs b/crates/compiler/types/src/subs.rs index df259af3cf..22cc3009ac 100644 --- a/crates/compiler/types/src/subs.rs +++ b/crates/compiler/types/src/subs.rs @@ -25,7 +25,8 @@ roc_error_macros::assert_sizeof_all!(RecordFields, 2 * 8); pub struct Mark(i32); impl Mark { - pub const NONE: Mark = Mark(2); + pub const NONE: Mark = Mark(3); + pub const VISITED_IN_OCCURS_CHECK: Mark = Mark(2); pub const OCCURS: Mark = Mark(1); pub const GET_VAR_NAMES: Mark = Mark(0); @@ -1996,9 +1997,15 @@ impl Subs { /// /// This ignores [Content::RecursionVar]s that occur recursively, because those are /// already priced in and expected to occur. - pub fn occurs(&self, var: Variable) -> Result<(), (Variable, Vec)> { + /// + /// Although `subs` is taken as mutable reference, this function will return it in the same + /// state it was given. + pub fn occurs(&mut self, var: Variable) -> Result<(), (Variable, Vec)> { let mut scratchpad = take_occurs_scratchpad(); let result = occurs(self, &mut scratchpad, var); + for v in &scratchpad.all_visited { + self.set_mark_unchecked(*v, Mark::NONE); + } put_occurs_scratchpad(scratchpad); result } @@ -3434,15 +3441,34 @@ impl TupleElems { } } -std::thread_local! { - static SCRATCHPAD_FOR_OCCURS: RefCell>> = RefCell::new(Some(Vec::with_capacity(1024))); +struct OccursScratchpad { + seen: Vec, + all_visited: Vec, } -fn take_occurs_scratchpad() -> Vec { +impl OccursScratchpad { + fn new_static() -> Self { + Self { + seen: Vec::with_capacity(1024), + all_visited: Vec::with_capacity(1024), + } + } + + fn clear(&mut self) { + self.seen.clear(); + self.all_visited.clear(); + } +} + +std::thread_local! { + static SCRATCHPAD_FOR_OCCURS: RefCell> = RefCell::new(Some(OccursScratchpad::new_static())); +} + +fn take_occurs_scratchpad() -> OccursScratchpad { SCRATCHPAD_FOR_OCCURS.with(|f| f.take().unwrap()) } -fn put_occurs_scratchpad(mut scratchpad: Vec) { +fn put_occurs_scratchpad(mut scratchpad: OccursScratchpad) { SCRATCHPAD_FOR_OCCURS.with(|f| { scratchpad.clear(); f.replace(Some(scratchpad)); @@ -3450,19 +3476,36 @@ fn put_occurs_scratchpad(mut scratchpad: Vec) { } fn occurs( - subs: &Subs, - seen: &mut Vec, + subs: &mut Subs, + ctx: &mut OccursScratchpad, input_var: Variable, ) -> Result<(), (Variable, Vec)> { + // NB(subs-invariant): it is pivotal that subs is not modified in any material way. + // As variables are visited, they are marked as observed so they are not revisited, + // but no other modification should take place. + use self::Content::*; use self::FlatType::*; let root_var = subs.get_root_key_without_compacting(input_var); - if seen.contains(&root_var) { + // SAFETY: due to XREF(subs-invariant), only the mark in a variable is modified, and all + // variable (and other content) identities are guaranteed to be preserved during an occurs + // check. As a result, we can freely take references of variables and UnionTags. + macro_rules! safe { + ($t:ty, $expr:expr) => { + unsafe { std::mem::transmute::<_, &'static $t>($expr) } + }; + } + + if ctx.seen.contains(&root_var) { Err((root_var, Vec::with_capacity(0))) + } else if subs.get_mark_unchecked(root_var) == Mark::VISITED_IN_OCCURS_CHECK { + Ok(()) } else { - seen.push(root_var); + ctx.all_visited.push(root_var); + subs.set_mark_unchecked(root_var, Mark::VISITED_IN_OCCURS_CHECK); + ctx.seen.push(root_var); let result = (|| match subs.get_content_without_compacting(root_var) { FlexVar(_) | RigidVar(_) @@ -3472,47 +3515,57 @@ fn occurs( | Error => Ok(()), Structure(flat_type) => match flat_type { - Apply(_, args) => { - short_circuit(subs, root_var, seen, subs.get_subs_slice(*args).iter()) - } + Apply(_, args) => short_circuit( + subs, + root_var, + ctx, + safe!([Variable], subs.get_subs_slice(*args)).iter(), + ), Func(arg_vars, closure_var, ret_var) => { - let it = once(ret_var) - .chain(once(closure_var)) - .chain(subs.get_subs_slice(*arg_vars).iter()); - short_circuit(subs, root_var, seen, it) + let it = once(safe!(Variable, ret_var)) + .chain(once(safe!(Variable, closure_var))) + .chain(safe!([Variable], subs.get_subs_slice(*arg_vars)).iter()); + short_circuit(subs, root_var, ctx, it) } Record(vars_by_field, ext) => { - let slice = SubsSlice::new(vars_by_field.variables_start, vars_by_field.length); - let it = once(ext).chain(subs.get_subs_slice(slice).iter()); - short_circuit(subs, root_var, seen, it) + let slice = + VariableSubsSlice::new(vars_by_field.variables_start, vars_by_field.length); + let it = once(safe!(Variable, ext)) + .chain(safe!([Variable], subs.get_subs_slice(slice)).iter()); + short_circuit(subs, root_var, ctx, it) } Tuple(vars_by_elem, ext) => { - let slice = SubsSlice::new(vars_by_elem.variables_start, vars_by_elem.length); - let it = once(ext).chain(subs.get_subs_slice(slice).iter()); - short_circuit(subs, root_var, seen, it) + let slice = + VariableSubsSlice::new(vars_by_elem.variables_start, vars_by_elem.length); + let it = once(safe!(Variable, ext)) + .chain(safe!([Variable], subs.get_subs_slice(slice)).iter()); + short_circuit(subs, root_var, ctx, it) } TagUnion(tags, ext) => { - occurs_union(subs, root_var, seen, tags)?; + let ext_var = ext.var(); + occurs_union(subs, root_var, ctx, safe!(UnionLabels, tags))?; - short_circuit_help(subs, root_var, seen, ext.var()) + short_circuit_help(subs, root_var, ctx, ext_var) } FunctionOrTagUnion(_, _, ext) => { - short_circuit(subs, root_var, seen, once(&ext.var())) + short_circuit(subs, root_var, ctx, once(&ext.var())) } RecursiveTagUnion(_, tags, ext) => { - occurs_union(subs, root_var, seen, tags)?; + let ext_var = ext.var(); + occurs_union(subs, root_var, ctx, safe!(UnionLabels, tags))?; - short_circuit_help(subs, root_var, seen, ext.var()) + short_circuit_help(subs, root_var, ctx, ext_var) } EmptyRecord | EmptyTuple | EmptyTagUnion => Ok(()), }, Alias(_, args, real_var, _) => { + let real_var = *real_var; for var_index in args.into_iter() { let var = subs[var_index]; - if short_circuit_help(subs, root_var, seen, var).is_err() { + if short_circuit_help(subs, root_var, ctx, var).is_err() { // Pay the cost and figure out what the actual recursion point is - return short_circuit_help(subs, root_var, seen, *real_var); + return short_circuit_help(subs, root_var, ctx, real_var); } } @@ -3527,27 +3580,27 @@ fn occurs( // unspecialized lambda vars excluded because they are not explicitly part of the // type (they only matter after being resolved). - occurs_union(subs, root_var, seen, solved) + occurs_union(subs, root_var, ctx, safe!(UnionLabels, solved)) } RangedNumber(_range_vars) => Ok(()), })(); - seen.pop(); + ctx.seen.pop(); result } } #[inline(always)] fn occurs_union( - subs: &Subs, + subs: &mut Subs, root_var: Variable, - seen: &mut Vec, + ctx: &mut OccursScratchpad, tags: &UnionLabels, ) -> Result<(), (Variable, Vec)> { for slice_index in tags.variables() { let slice = subs[slice_index]; for var_index in slice { let var = subs[var_index]; - short_circuit_help(subs, root_var, seen, var)?; + short_circuit_help(subs, root_var, ctx, var)?; } } Ok(()) @@ -3555,16 +3608,16 @@ fn occurs_union( #[inline(always)] fn short_circuit<'a, T>( - subs: &Subs, + subs: &mut Subs, root_key: Variable, - seen: &mut Vec, + ctx: &mut OccursScratchpad, iter: T, ) -> Result<(), (Variable, Vec)> where T: Iterator, { for var in iter { - short_circuit_help(subs, root_key, seen, *var)?; + short_circuit_help(subs, root_key, ctx, *var)?; } Ok(()) @@ -3572,12 +3625,12 @@ where #[inline(always)] fn short_circuit_help( - subs: &Subs, + subs: &mut Subs, root_key: Variable, - seen: &mut Vec, + ctx: &mut OccursScratchpad, var: Variable, ) -> Result<(), (Variable, Vec)> { - if let Err((v, mut vec)) = occurs(subs, seen, var) { + if let Err((v, mut vec)) = occurs(subs, ctx, var) { vec.push(root_key); return Err((v, vec)); }