diff --git a/compiler/solve/src/solve.rs b/compiler/solve/src/solve.rs index 11c47811e6..5e1e9a9a68 100644 --- a/compiler/solve/src/solve.rs +++ b/compiler/solve/src/solve.rs @@ -3045,27 +3045,35 @@ fn deep_copy_var_in( ) -> Variable { let mut visited = bumpalo::collections::Vec::with_capacity_in(256, arena); - let copy = deep_copy_var_help(subs, rank, pools, &mut visited, var); + let pool = pools.get_mut(rank); + let copy = deep_copy_var_help(subs, rank, pool, &mut visited, var); // we have tracked all visited variables, and can now traverse them // in one go (without looking at the UnificationTable) and clear the copy field for var in visited { - subs.modify(var, |descriptor| { - if descriptor.copy.is_some() { - descriptor.rank = Rank::NONE; - descriptor.mark = Mark::NONE; - descriptor.copy = OptVariable::NONE; - } - }); + subs.set_copy_unchecked(var, OptVariable::NONE); } copy } +#[inline] +fn has_trivial_copy(subs: &Subs, root_var: Variable) -> Option { + let existing_copy = subs.get_copy_unchecked(root_var); + + if let Some(copy) = existing_copy.into_variable() { + Some(copy) + } else if subs.get_rank_unchecked(root_var) != Rank::NONE { + Some(root_var) + } else { + None + } +} + fn deep_copy_var_help( subs: &mut Subs, max_rank: Rank, - pools: &mut Pools, + pool: &mut Vec, visited: &mut bumpalo::collections::Vec<'_, Variable>, var: Variable, ) -> Variable { @@ -3073,44 +3081,39 @@ fn deep_copy_var_help( use roc_types::subs::FlatType::*; let subs_len = subs.len(); + let var = subs.get_root_key(var); - let existing_copy = subs.get_copy(var); - - if let Some(copy) = existing_copy.into_variable() { + // either this variable has been copied before, or does not have NONE rank + if let Some(copy) = has_trivial_copy(subs, var) { return copy; - } else if subs.get_rank(var) != Rank::NONE { - return var; } - visited.push(var); - - let make_descriptor = |content| Descriptor { - content, - rank: max_rank, - mark: Mark::NONE, - copy: OptVariable::NONE, - }; - // Safety: Here we make a variable that is 1 position out of bounds. // The reason is that we can now keep the mutable reference to `desc` // Below, we actually push a new variable onto subs meaning the `copy` // variable is in-bounds before it is ever used. let copy = unsafe { Variable::from_index(subs_len as u32) }; - pools.get_mut(max_rank).push(copy); + visited.push(var); + pool.push(copy); // Link the original variable to the new variable. This lets us // avoid making multiple copies of the variable we are instantiating. // // Need to do this before recursively copying to avoid looping. - subs.modify(var, |desc| { - desc.mark = Mark::NONE; - desc.copy = copy.into(); - }); + subs.set_mark_unchecked(var, Mark::NONE); + subs.set_copy_unchecked(var, copy.into()); - let content = *subs.get_content_without_compacting(var); + let content = *subs.get_content_unchecked(var); - let actual_copy = subs.fresh(make_descriptor(content)); + let copy_descriptor = Descriptor { + content, + rank: max_rank, + mark: Mark::NONE, + copy: OptVariable::NONE, + }; + + let actual_copy = subs.fresh(copy_descriptor); debug_assert_eq!(copy, actual_copy); macro_rules! copy_sequence { @@ -3118,7 +3121,7 @@ fn deep_copy_var_help( let new_variables = SubsSlice::reserve_into_subs(subs, $length as _); for (target_index, var_index) in (new_variables.indices()).zip($variables) { let var = subs[var_index]; - let copy_var = deep_copy_var_help(subs, max_rank, pools, visited, var); + let copy_var = deep_copy_var_help(subs, max_rank, pool, visited, var); subs.variables[target_index] = copy_var; } @@ -3139,9 +3142,9 @@ fn deep_copy_var_help( } Func(arguments, closure_var, ret_var) => { - let new_ret_var = deep_copy_var_help(subs, max_rank, pools, visited, ret_var); + let new_ret_var = deep_copy_var_help(subs, max_rank, pool, visited, ret_var); let new_closure_var = - deep_copy_var_help(subs, max_rank, pools, visited, closure_var); + deep_copy_var_help(subs, max_rank, pool, visited, closure_var); let new_arguments = copy_sequence!(arguments.len(), arguments); @@ -3164,7 +3167,7 @@ fn deep_copy_var_help( Record( record_fields, - deep_copy_var_help(subs, max_rank, pools, visited, ext_var), + deep_copy_var_help(subs, max_rank, pool, visited, ext_var), ) } @@ -3181,14 +3184,14 @@ fn deep_copy_var_help( let union_tags = UnionTags::from_slices(tags.tag_names(), new_variable_slices); - let new_ext = deep_copy_var_help(subs, max_rank, pools, visited, ext_var); + let new_ext = deep_copy_var_help(subs, max_rank, pool, visited, ext_var); TagUnion(union_tags, new_ext) } FunctionOrTagUnion(tag_name, symbol, ext_var) => FunctionOrTagUnion( tag_name, symbol, - deep_copy_var_help(subs, max_rank, pools, visited, ext_var), + deep_copy_var_help(subs, max_rank, pool, visited, ext_var), ), RecursiveTagUnion(rec_var, tags, ext_var) => { @@ -3204,14 +3207,14 @@ fn deep_copy_var_help( let union_tags = UnionTags::from_slices(tags.tag_names(), new_variable_slices); - let new_ext = deep_copy_var_help(subs, max_rank, pools, visited, ext_var); - let new_rec_var = deep_copy_var_help(subs, max_rank, pools, visited, rec_var); + let new_ext = deep_copy_var_help(subs, max_rank, pool, visited, ext_var); + let new_rec_var = deep_copy_var_help(subs, max_rank, pool, visited, rec_var); RecursiveTagUnion(new_rec_var, union_tags, new_ext) } }; - subs.set(copy, make_descriptor(Structure(new_flat_type))); + subs.set_content_unchecked(copy, Structure(new_flat_type)); copy } @@ -3222,27 +3225,26 @@ fn deep_copy_var_help( opt_name, structure, } => { - let new_structure = deep_copy_var_help(subs, max_rank, pools, visited, structure); + let new_structure = deep_copy_var_help(subs, max_rank, pool, visited, structure); - subs.set( - copy, - make_descriptor(RecursionVar { - opt_name, - structure: new_structure, - }), - ); + let content = RecursionVar { + opt_name, + structure: new_structure, + }; + + subs.set_content_unchecked(copy, content); copy } RigidVar(name) => { - subs.set(copy, make_descriptor(FlexVar(Some(name)))); + subs.set_content_unchecked(copy, FlexVar(Some(name))); copy } RigidAbleVar(name, ability) => { - subs.set(copy, make_descriptor(FlexAbleVar(Some(name), ability))); + subs.set_content_unchecked(copy, FlexAbleVar(Some(name), ability)); copy } @@ -3257,22 +3259,22 @@ fn deep_copy_var_help( }; let new_real_type_var = - deep_copy_var_help(subs, max_rank, pools, visited, real_type_var); + deep_copy_var_help(subs, max_rank, pool, visited, real_type_var); let new_content = Alias(symbol, new_arguments, new_real_type_var, kind); - subs.set(copy, make_descriptor(new_content)); + subs.set_content_unchecked(copy, new_content); copy } RangedNumber(typ, range_vars) => { - let new_type_var = deep_copy_var_help(subs, max_rank, pools, visited, typ); + let new_type_var = deep_copy_var_help(subs, max_rank, pool, visited, typ); let new_variables = copy_sequence!(range_vars.len(), range_vars); let new_content = RangedNumber(new_type_var, new_variables); - subs.set(copy, make_descriptor(new_content)); + subs.set_content_unchecked(copy, new_content); copy } diff --git a/compiler/types/src/subs.rs b/compiler/types/src/subs.rs index a3626481a4..da17f69d44 100644 --- a/compiler/types/src/subs.rs +++ b/compiler/types/src/subs.rs @@ -1659,6 +1659,10 @@ impl Subs { self.utable.get_rank_unchecked(key) } + pub fn get_copy_unchecked(&self, key: Variable) -> OptVariable { + self.utable.get_copy_unchecked(key) + } + #[inline(always)] pub fn get_without_compacting(&self, key: Variable) -> Descriptor { self.utable.get_descriptor(key) @@ -1702,6 +1706,10 @@ impl Subs { self.utable.set_mark_unchecked(key, mark) } + pub fn set_copy_unchecked(&mut self, key: Variable, copy: OptVariable) { + self.utable.set_copy_unchecked(key, copy) + } + pub fn set_copy(&mut self, key: Variable, copy: OptVariable) { self.utable.set_copy(key, copy) } @@ -1712,11 +1720,13 @@ impl Subs { } pub fn set_content(&mut self, key: Variable, content: Content) { - // let l_key = self.utable.inlined_get_root_key(key); - self.utable.set_content(key, content); } + pub fn set_content_unchecked(&mut self, key: Variable, content: Content) { + self.utable.set_content_unchecked(key, content); + } + pub fn modify(&mut self, key: Variable, mapper: F) -> T where F: FnOnce(&mut Descriptor) -> T,