diff --git a/compiler/types/src/subs.rs b/compiler/types/src/subs.rs index ca71ab56cd..726aab1e38 100644 --- a/compiler/types/src/subs.rs +++ b/compiler/types/src/subs.rs @@ -4817,3 +4817,90 @@ fn copy_import_to_help(env: &mut CopyImportEnv<'_>, max_rank: Rank, var: Variabl } } } + +pub trait ContentVisitor<'a> { + fn subs(&self) -> &'a Subs; + + fn insert_recursion_var(&mut self, var: Variable); + fn seen_recursion_var(&self, var: Variable) -> bool; + + /// Visits the `content` for `var` and returns whether or not descend into `content` further. + /// Don't call this directly; `visit` will call this. + fn handle_content(&mut self, var: Variable, content: &Content) -> bool; + + /// Entry point to the visitor; descends the content of a variable until its end, or a cycle. + /// Attach handlers for different content using `handle_content`. You likely don't want to call + /// `visit` directly, as it uses an iterative rather than recursive visitor. + fn visit(&mut self, var: Variable) { + let mut stack = vec![var]; + + let subs = self.subs(); + + macro_rules! push_var_slice { + ($slice:expr) => { + stack.extend(subs.get_subs_slice($slice)) + }; + } + + while let Some(var) = stack.pop() { + if self.seen_recursion_var(var) { + continue; + } + + let content = subs.get_content_without_compacting(var); + + if !self.handle_content(var, content) { + continue; + } + + use Content::*; + use FlatType::*; + match content { + FlexVar(_) | RigidVar(_) | FlexAbleVar(_, _) | RigidAbleVar(_, _) => {} + RecursionVar { + structure, + opt_name: _, + } => { + self.insert_recursion_var(var); + stack.push(*structure); + } + Structure(flat_type) => match flat_type { + Apply(_, vars) => push_var_slice!(*vars), + Func(args, clos, ret) => { + push_var_slice!(*args); + stack.push(*clos); + stack.push(*ret); + } + Record(fields, var) => { + push_var_slice!(fields.variables()); + stack.push(*var); + } + TagUnion(tags, ext_var) => { + for i in tags.variables() { + push_var_slice!(subs[i]); + } + stack.push(*ext_var); + } + FunctionOrTagUnion(_, _, var) => stack.push(*var), + RecursiveTagUnion(rec_var, tags, ext_var) => { + self.insert_recursion_var(*rec_var); + for i in tags.variables() { + push_var_slice!(subs[i]); + } + stack.push(*ext_var); + } + Erroneous(_) | EmptyRecord | EmptyTagUnion => {} + }, + Alias(_, arguments, real_type_var, _) => { + push_var_slice!(arguments.all_variables()); + stack.push(*real_type_var); + } + RangedNumber(typ, vars) => { + stack.push(*typ); + push_var_slice!(*vars); + } + Error => {} + } + } + } +}