From 328eea8b252078f75eef97ee3d22ffda33a3887a Mon Sep 17 00:00:00 2001 From: Ayaz Hafiz Date: Wed, 22 Jun 2022 17:19:01 -0400 Subject: [PATCH] Add extra metadata parameter to unification --- ast/src/solve_type.rs | 4 + compiler/late_solve/src/lib.rs | 1 + compiler/solve/src/ability.rs | 2 +- compiler/solve/src/solve.rs | 9 +- compiler/unify/src/unify.rs | 175 ++++++++++++++++++++++----------- 5 files changed, 130 insertions(+), 61 deletions(-) diff --git a/ast/src/solve_type.rs b/ast/src/solve_type.rs index 9c87d37801..8399158e17 100644 --- a/ast/src/solve_type.rs +++ b/ast/src/solve_type.rs @@ -232,6 +232,7 @@ fn solve<'a>( vars, must_implement_ability: _, lambda_sets_to_specialize: _, // TODO ignored + extra_metadata: _, } => { // TODO(abilities) record deferred ability checks introduce(subs, rank, pools, &vars); @@ -330,6 +331,7 @@ fn solve<'a>( vars, must_implement_ability: _, lambda_sets_to_specialize: _, // TODO ignored + extra_metadata: _, } => { // TODO(abilities) record deferred ability checks introduce(subs, rank, pools, &vars); @@ -406,6 +408,7 @@ fn solve<'a>( vars, must_implement_ability: _, lambda_sets_to_specialize: _, // TODO ignored + extra_metadata: _, } => { // TODO(abilities) record deferred ability checks introduce(subs, rank, pools, &vars); @@ -719,6 +722,7 @@ fn solve<'a>( vars, must_implement_ability: _, lambda_sets_to_specialize: _, // TODO ignored + extra_metadata: _, } => { // TODO(abilities) record deferred ability checks introduce(subs, rank, pools, &vars); diff --git a/compiler/late_solve/src/lib.rs b/compiler/late_solve/src/lib.rs index 012c96d7f8..399d35299c 100644 --- a/compiler/late_solve/src/lib.rs +++ b/compiler/late_solve/src/lib.rs @@ -145,6 +145,7 @@ pub fn unify( vars: _, must_implement_ability: _, lambda_sets_to_specialize, + extra_metadata: _, } => { let mut pools = Pools::default(); diff --git a/compiler/solve/src/ability.rs b/compiler/solve/src/ability.rs index b0aeb8ab27..1e7bc75e29 100644 --- a/compiler/solve/src/ability.rs +++ b/compiler/solve/src/ability.rs @@ -652,7 +652,7 @@ pub fn resolve_ability_specialization( let signature_var = member_def.signature_var(); instantiate_rigids(subs, signature_var); - let (_vars, must_implement_ability, _lambda_sets_to_specialize) = + let (_vars, must_implement_ability, _lambda_sets_to_specialize, _meta) = unify(subs, specialization_var, signature_var, Mode::EQ).expect_success( "If resolving a specialization, the specialization must be known to typecheck.", ); diff --git a/compiler/solve/src/solve.rs b/compiler/solve/src/solve.rs index 6b539bdefc..aced1f6294 100644 --- a/compiler/solve/src/solve.rs +++ b/compiler/solve/src/solve.rs @@ -890,6 +890,7 @@ fn solve( vars, must_implement_ability, lambda_sets_to_specialize, + extra_metadata: _, } => { introduce(subs, rank, pools, &vars); if !must_implement_ability.is_empty() { @@ -944,6 +945,7 @@ fn solve( // ERROR NOT REPORTED must_implement_ability: _, lambda_sets_to_specialize, + extra_metadata: _, } => { introduce(subs, rank, pools, &vars); @@ -1002,6 +1004,7 @@ fn solve( vars, must_implement_ability, lambda_sets_to_specialize, + extra_metadata: _, } => { introduce(subs, rank, pools, &vars); if !must_implement_ability.is_empty() { @@ -1081,6 +1084,7 @@ fn solve( vars, must_implement_ability, lambda_sets_to_specialize, + extra_metadata: _, } => { introduce(subs, rank, pools, &vars); if !must_implement_ability.is_empty() { @@ -1245,6 +1249,7 @@ fn solve( vars, must_implement_ability, lambda_sets_to_specialize, + extra_metadata: _, } => { introduce(subs, rank, pools, &vars); if !must_implement_ability.is_empty() { @@ -1351,6 +1356,7 @@ fn solve( vars, must_implement_ability, lambda_sets_to_specialize, + extra_metadata: _, } => { subs.commit_snapshot(snapshot); @@ -1609,6 +1615,7 @@ fn check_ability_specialization( vars, must_implement_ability, lambda_sets_to_specialize, + extra_metadata: _, } => { let specialization_type = type_implementing_specialization(&must_implement_ability, parent_ability); @@ -1953,7 +1960,7 @@ fn compact_lambda_set( subs.set_content(this_lambda_set, partial_compacted_lambda_set); for other_specialized in specialized_to_unify_with.into_iter() { - let (vars, must_implement_ability, lambda_sets_to_specialize) = + let (vars, must_implement_ability, lambda_sets_to_specialize, _meta) = unify(subs, this_lambda_set, other_specialized, Mode::EQ) .expect_success("lambda sets don't unify"); diff --git a/compiler/unify/src/unify.rs b/compiler/unify/src/unify.rs index ca7cbdc3e5..08d0642de7 100644 --- a/compiler/unify/src/unify.rs +++ b/compiler/unify/src/unify.rs @@ -138,28 +138,62 @@ pub struct Context { mode: Mode, } +pub trait MetaCollector: Default + std::fmt::Debug { + /// Whether we are performing `member ~ specialization` where `member` is an ability member + /// signature and `specialization` is an ability specialization for a given type. When this is + /// the case, given a lambda set unification like + /// `[[] + a:member:1] ~ [specialization-lambda-set]`, only the specialization lambda set will + /// be kept around, and the record `(member, 1) => specialization-lambda-set` will be + /// associated via [`Self::record_specialization_lambda_set`]. + const UNIFYING_SPECIALIZATION: bool; + + fn record_specialization_lambda_set(&mut self, member: Symbol, region: u8, var: Variable); + + fn union(&mut self, other: Self); +} + +#[derive(Default, Debug)] +pub struct NoCollector; +impl MetaCollector for NoCollector { + const UNIFYING_SPECIALIZATION: bool = false; + + fn record_specialization_lambda_set(&mut self, _member: Symbol, _region: u8, _var: Variable) {} + + fn union(&mut self, _other: Self) {} +} + #[derive(Debug)] -pub enum Unified { +pub enum Unified { Success { vars: Pool, must_implement_ability: MustImplementConstraints, lambda_sets_to_specialize: UlsOfVar, + + /// The vast majority of the time the extra metadata is empty, so we make unification + /// polymorphic over metadata collection to avoid unnecessary memory usage. + extra_metadata: M, }, Failure(Pool, ErrorType, ErrorType, DoesNotImplementAbility), BadType(Pool, roc_types::types::Problem), } -impl Unified { +impl Unified { pub fn expect_success( self, err_msg: &'static str, - ) -> (Pool, MustImplementConstraints, UlsOfVar) { + ) -> (Pool, MustImplementConstraints, UlsOfVar, M) { match self { Unified::Success { vars, must_implement_ability, lambda_sets_to_specialize, - } => (vars, must_implement_ability, lambda_sets_to_specialize), + extra_metadata, + } => ( + vars, + must_implement_ability, + lambda_sets_to_specialize, + extra_metadata, + ), _ => internal_error!("{}", err_msg), } } @@ -212,7 +246,7 @@ impl MustImplementConstraints { } #[derive(Debug, Default)] -pub struct Outcome { +pub struct Outcome { mismatches: Vec, /// We defer these checks until the end of a solving phase. /// NOTE: this vector is almost always empty! @@ -220,25 +254,38 @@ pub struct Outcome { /// We defer resolution of these lambda sets to the caller of [unify]. /// See also [merge_flex_able_with_concrete]. lambda_sets_to_specialize: UlsOfVar, + extra_metadata: M, } -impl Outcome { +impl Outcome { fn union(&mut self, other: Self) { self.mismatches.extend(other.mismatches); self.must_implement_ability .extend(other.must_implement_ability); self.lambda_sets_to_specialize .union(other.lambda_sets_to_specialize); + self.extra_metadata.union(other.extra_metadata); } } #[inline(always)] pub fn unify(subs: &mut Subs, var1: Variable, var2: Variable, mode: Mode) -> Unified { + unify_help(subs, var1, var2, mode) +} + +#[inline(always)] +fn unify_help( + subs: &mut Subs, + var1: Variable, + var2: Variable, + mode: Mode, +) -> Unified { let mut vars = Vec::new(); let Outcome { mismatches, must_implement_ability, lambda_sets_to_specialize, + extra_metadata, } = unify_pool(subs, &mut vars, var1, var2, mode); if mismatches.is_empty() { @@ -246,6 +293,7 @@ pub fn unify(subs: &mut Subs, var1: Variable, var2: Variable, mode: Mode) -> Uni vars, must_implement_ability, lambda_sets_to_specialize, + extra_metadata, } } else { let error_context = if mismatches.contains(&Mismatch::TypeNotInRange) { @@ -282,13 +330,13 @@ pub fn unify(subs: &mut Subs, var1: Variable, var2: Variable, mode: Mode) -> Uni } #[inline(always)] -pub fn unify_pool( +pub fn unify_pool( subs: &mut Subs, pool: &mut Pool, var1: Variable, var2: Variable, mode: Mode, -) -> Outcome { +) -> Outcome { if subs.equivalent(var1, var2) { Outcome::default() } else { @@ -308,7 +356,11 @@ pub fn unify_pool( /// a tree to stderr. /// NOTE: Only run this on individual tests! Run on multiple threads, this would clobber each others' output. #[cfg(debug_assertions)] -fn debug_print_unified_types(subs: &mut Subs, ctx: &Context, opt_outcome: Option<&Outcome>) { +fn debug_print_unified_types( + subs: &mut Subs, + ctx: &Context, + opt_outcome: Option<&Outcome>, +) { use roc_types::subs::SubsFmtContent; static mut UNIFICATION_DEPTH: usize = 0; @@ -358,9 +410,9 @@ fn debug_print_unified_types(subs: &mut Subs, ctx: &Context, opt_outcome: Option }) } -fn unify_context(subs: &mut Subs, pool: &mut Pool, ctx: Context) -> Outcome { +fn unify_context(subs: &mut Subs, pool: &mut Pool, ctx: Context) -> Outcome { #[cfg(debug_assertions)] - debug_print_unified_types(subs, &ctx, None); + debug_print_unified_types::(subs, &ctx, None); // This #[allow] is needed in release builds, where `result` is no longer used. #[allow(clippy::let_and_return)] @@ -408,13 +460,13 @@ fn unify_context(subs: &mut Subs, pool: &mut Pool, ctx: Context) -> Outcome { } #[inline(always)] -fn unify_ranged_number( +fn unify_ranged_number( subs: &mut Subs, pool: &mut Pool, ctx: &Context, real_var: Variable, range_vars: NumericRange, -) -> Outcome { +) -> Outcome { let other_content = &ctx.second_desc.content; let outcome = match other_content { @@ -448,7 +500,11 @@ fn unify_ranged_number( check_valid_range(subs, ctx.second, range_vars) } -fn check_valid_range(subs: &mut Subs, var: Variable, range: NumericRange) -> Outcome { +fn check_valid_range( + subs: &mut Subs, + var: Variable, + range: NumericRange, +) -> Outcome { let content = subs.get_content_without_compacting(var); match content { @@ -463,6 +519,7 @@ fn check_valid_range(subs: &mut Subs, var: Variable, range: NumericRange) -> Out mismatches: vec![Mismatch::TypeNotInRange], must_implement_ability: Default::default(), lambda_sets_to_specialize: Default::default(), + extra_metadata: Default::default(), }; return outcome; @@ -485,7 +542,7 @@ fn check_valid_range(subs: &mut Subs, var: Variable, range: NumericRange) -> Out #[inline(always)] #[allow(clippy::too_many_arguments)] -fn unify_two_aliases( +fn unify_two_aliases( subs: &mut Subs, pool: &mut Pool, ctx: &Context, @@ -496,7 +553,7 @@ fn unify_two_aliases( other_args: AliasVariables, other_real_var: Variable, other_content: &Content, -) -> Outcome { +) -> Outcome { if args.len() == other_args.len() { let mut outcome = Outcome::default(); let it = args @@ -534,14 +591,14 @@ fn unify_two_aliases( // Unifies a structural alias #[inline(always)] -fn unify_alias( +fn unify_alias( subs: &mut Subs, pool: &mut Pool, ctx: &Context, symbol: Symbol, args: AliasVariables, real_var: Variable, -) -> Outcome { +) -> Outcome { let other_content = &ctx.second_desc.content; let kind = AliasKind::Structural; @@ -588,14 +645,14 @@ fn unify_alias( } #[inline(always)] -fn unify_opaque( +fn unify_opaque( subs: &mut Subs, pool: &mut Pool, ctx: &Context, symbol: Symbol, args: AliasVariables, real_var: Variable, -) -> Outcome { +) -> Outcome { let other_content = &ctx.second_desc.content; let kind = AliasKind::Opaque; @@ -655,13 +712,13 @@ fn unify_opaque( } #[inline(always)] -fn unify_structure( +fn unify_structure( subs: &mut Subs, pool: &mut Pool, ctx: &Context, flat_type: &FlatType, other: &Content, -) -> Outcome { +) -> Outcome { match other { FlexVar(_) => { // If the other is flex, Structure wins! @@ -793,13 +850,13 @@ fn unify_structure( } #[inline(always)] -fn unify_lambda_set( +fn unify_lambda_set( subs: &mut Subs, pool: &mut Pool, ctx: &Context, lambda_set: LambdaSet, other: &Content, -) -> Outcome { +) -> Outcome { match other { FlexVar(_) => merge(subs, ctx, Content::LambdaSet(lambda_set)), Content::LambdaSet(other_lambda_set) => { @@ -818,13 +875,13 @@ fn unify_lambda_set( } } -fn unify_lambda_set_help( +fn unify_lambda_set_help( subs: &mut Subs, pool: &mut Pool, ctx: &Context, lset1: self::LambdaSet, lset2: self::LambdaSet, -) -> Outcome { +) -> Outcome { // LambdaSets unify like TagUnions, but can grow unbounded regardless of the extension // variable. @@ -876,7 +933,7 @@ fn unify_lambda_set_help( maybe_mark_union_recursive(subs, var1); maybe_mark_union_recursive(subs, var2); - let outcome = unify_pool(subs, pool, var1, var2, ctx.mode); + let outcome = unify_pool::(subs, pool, var1, var2, ctx.mode); if outcome.mismatches.is_empty() { matching_vars.push(var1); @@ -980,12 +1037,12 @@ fn unify_lambda_set_help( // resolve these cases here. // // See tests labeled "issue_2810" for more examples. -fn fix_tag_union_recursion_variable( +fn fix_tag_union_recursion_variable( subs: &mut Subs, ctx: &Context, tag_union_promoted_to_recursive: Variable, recursion_var: &Content, -) -> Outcome { +) -> Outcome { debug_assert!(matches!( subs.get_content_without_compacting(tag_union_promoted_to_recursive), Structure(FlatType::RecursiveTagUnion(..)) @@ -1002,7 +1059,7 @@ fn fix_tag_union_recursion_variable( } } -fn unify_record( +fn unify_record( subs: &mut Subs, pool: &mut Pool, ctx: &Context, @@ -1010,7 +1067,7 @@ fn unify_record( ext1: Variable, fields2: RecordFields, ext2: Variable, -) -> Outcome { +) -> Outcome { let (separate, ext1, ext2) = separate_record_fields(subs, fields1, ext1, fields2, ext2); let shared_fields = separate.in_both; @@ -1118,14 +1175,14 @@ enum OtherFields { type SharedFields = Vec<(Lowercase, (RecordField, RecordField))>; -fn unify_shared_fields( +fn unify_shared_fields( subs: &mut Subs, pool: &mut Pool, ctx: &Context, shared_fields: SharedFields, other_fields: OtherFields, ext: Variable, -) -> Outcome { +) -> Outcome { let mut matching_fields = Vec::with_capacity(shared_fields.len()); let num_shared_fields = shared_fields.len(); @@ -1375,7 +1432,7 @@ enum Rec { } #[allow(clippy::too_many_arguments)] -fn unify_tag_unions( +fn unify_tag_unions( subs: &mut Subs, pool: &mut Pool, ctx: &Context, @@ -1384,7 +1441,7 @@ fn unify_tag_unions( tags2: UnionTags, initial_ext2: Variable, recursion_var: Rec, -) -> Outcome { +) -> Outcome { let (separate, mut ext1, ext2) = separate_union_tags(subs, tags1, initial_ext1, tags2, initial_ext2); @@ -1593,7 +1650,7 @@ fn maybe_mark_union_recursive(subs: &mut Subs, union_var: Variable) { } } -fn unify_shared_tags_new( +fn unify_shared_tags_new( subs: &mut Subs, pool: &mut Pool, ctx: &Context, @@ -1601,7 +1658,7 @@ fn unify_shared_tags_new( other_tags: OtherTags2, ext: Variable, recursion_var: Rec, -) -> Outcome { +) -> Outcome { let mut matching_tags = Vec::default(); let num_shared_tags = shared_tags.len(); @@ -1646,7 +1703,7 @@ fn unify_shared_tags_new( maybe_mark_union_recursive(subs, actual); maybe_mark_union_recursive(subs, expected); - let mut outcome = Outcome::default(); + let mut outcome = Outcome::::default(); outcome.union(unify_pool(subs, pool, actual, expected, ctx.mode)); @@ -1719,13 +1776,13 @@ fn unify_shared_tags_new( } } -fn unify_shared_tags_merge_new( +fn unify_shared_tags_merge_new( subs: &mut Subs, ctx: &Context, new_tags: UnionTags, new_ext_var: Variable, recursion_var: Rec, -) -> Outcome { +) -> Outcome { let flat_type = match recursion_var { Rec::None => FlatType::TagUnion(new_tags, new_ext_var), Rec::Left(rec) | Rec::Right(rec) | Rec::Both(rec, _) => { @@ -1738,13 +1795,13 @@ fn unify_shared_tags_merge_new( } #[inline(always)] -fn unify_flat_type( +fn unify_flat_type( subs: &mut Subs, pool: &mut Pool, ctx: &Context, left: &FlatType, right: &FlatType, -) -> Outcome { +) -> Outcome { use roc_types::subs::FlatType::*; match (left, right) { @@ -1924,12 +1981,12 @@ fn unify_flat_type( } } -fn unify_zip_slices( +fn unify_zip_slices( subs: &mut Subs, pool: &mut Pool, left: SubsSlice, right: SubsSlice, -) -> Outcome { +) -> Outcome { let mut outcome = Outcome::default(); let it = left.into_iter().zip(right.into_iter()); @@ -1945,12 +2002,12 @@ fn unify_zip_slices( } #[inline(always)] -fn unify_rigid( +fn unify_rigid( subs: &mut Subs, ctx: &Context, name: &SubsIndex, other: &Content, -) -> Outcome { +) -> Outcome { match other { FlexVar(_) => { // If the other is flex, rigid wins! @@ -1985,13 +2042,13 @@ fn unify_rigid( } #[inline(always)] -fn unify_rigid_able( +fn unify_rigid_able( subs: &mut Subs, ctx: &Context, name: &SubsIndex, ability: Symbol, other: &Content, -) -> Outcome { +) -> Outcome { match other { FlexVar(_) => { // If the other is flex, rigid wins! @@ -2034,12 +2091,12 @@ fn unify_rigid_able( } #[inline(always)] -fn unify_flex( +fn unify_flex( subs: &mut Subs, ctx: &Context, opt_name: &Option>, other: &Content, -) -> Outcome { +) -> Outcome { match other { FlexVar(other_opt_name) => { // Prefer using right's name. @@ -2070,13 +2127,13 @@ fn unify_flex( } #[inline(always)] -fn unify_flex_able( +fn unify_flex_able( subs: &mut Subs, ctx: &Context, opt_name: &Option>, ability: Symbol, other: &Content, -) -> Outcome { +) -> Outcome { match other { FlexVar(opt_other_name) => { // Prefer using right's name. @@ -2147,14 +2204,14 @@ fn unify_flex_able( } } -fn merge_flex_able_with_concrete( +fn merge_flex_able_with_concrete( subs: &mut Subs, ctx: &Context, flex_able_var: Variable, ability: Symbol, concrete_content: Content, concrete_obligation: Obligated, -) -> Outcome { +) -> Outcome { let mut outcome = merge(subs, ctx, concrete_content); let must_implement_ability = MustImplementAbility { typ: concrete_obligation, @@ -2178,14 +2235,14 @@ fn merge_flex_able_with_concrete( } #[inline(always)] -fn unify_recursion( +fn unify_recursion( subs: &mut Subs, pool: &mut Pool, ctx: &Context, opt_name: &Option>, structure: Variable, other: &Content, -) -> Outcome { +) -> Outcome { match other { RecursionVar { opt_name: other_opt_name, @@ -2255,7 +2312,7 @@ fn unify_recursion( } } -pub fn merge(subs: &mut Subs, ctx: &Context, content: Content) -> Outcome { +pub fn merge(subs: &mut Subs, ctx: &Context, content: Content) -> Outcome { let rank = ctx.first_desc.rank.min(ctx.second_desc.rank); let desc = Descriptor { content, @@ -2298,7 +2355,7 @@ fn is_recursion_var(subs: &Subs, var: Variable) -> bool { } #[allow(clippy::too_many_arguments)] -fn unify_function_or_tag_union_and_func( +fn unify_function_or_tag_union_and_func( subs: &mut Subs, pool: &mut Pool, ctx: &Context, @@ -2309,7 +2366,7 @@ fn unify_function_or_tag_union_and_func( function_return: Variable, function_lambda_set: Variable, left: bool, -) -> Outcome { +) -> Outcome { let tag_name = subs[*tag_name_index].clone(); let union_tags = UnionTags::insert_slices_into_subs(subs, [(tag_name, function_arguments)]);