diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index 8e5c5a5a1e..baf8576c49 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -4220,10 +4220,14 @@ pub fn with_hole<'a>( // a proc in this module, or an imported symbol procs.partial_procs.contains_key(key) || (env.is_imported_symbol(key) && !procs.is_imported_module_thunk(key)) + || env.abilities_store.is_ability_member_name(key) }; match loc_expr.value { roc_can::expr::Expr::Var(proc_name) if is_known(proc_name) => { + // This might be an ability member - if so, use the appropriate specialization. + let proc_name = repoint_to_specialization(env, fn_var, proc_name); + // a call by a known name call_by_name( env, @@ -4707,6 +4711,43 @@ pub fn with_hole<'a>( } } +#[inline(always)] +fn repoint_to_specialization<'a>( + env: &mut Env<'a, '_>, + symbol_var: Variable, + symbol: Symbol, +) -> Symbol { + use roc_solve::ability::type_implementing_member; + use roc_unify::unify::unify; + + match env.abilities_store.member_def(symbol) { + None => { + // This is not an ability member, it doesn't need specialization. + symbol + } + Some(member) => { + let snapshot = env.subs.snapshot(); + let (_, must_implement_ability) = unify( + env.subs, + symbol_var, + member.signature_var, + roc_unify::unify::Mode::EQ, + ) + .expect_success("This typechecked previously"); + env.subs.rollback_to(snapshot); + let specializing_type = + type_implementing_member(&must_implement_ability, member.parent_ability); + + let specialization = env + .abilities_store + .get_specialization(symbol, specializing_type) + .expect("No specialization is recorded - I thought there would only be a type error here."); + + specialization.symbol + } + } +} + #[allow(clippy::too_many_arguments)] fn construct_closure_data<'a>( env: &mut Env<'a, '_>, diff --git a/compiler/solve/src/ability.rs b/compiler/solve/src/ability.rs index 3aaf27d32d..edbbd54bec 100644 --- a/compiler/solve/src/ability.rs +++ b/compiler/solve/src/ability.rs @@ -1,4 +1,5 @@ use roc_can::abilities::AbilitiesStore; +use roc_module::symbol::Symbol; use roc_region::all::{Loc, Region}; use roc_types::subs::Subs; use roc_types::subs::Variable; @@ -154,3 +155,23 @@ impl DeferredMustImplementAbility { problems } } + +/// Determines what type implements an ability member of a specialized signature, given the +/// [MustImplementAbility] constraints of the signature. +pub fn type_implementing_member( + specialization_must_implement_constraints: &[MustImplementAbility], + ability: Symbol, +) -> Symbol { + let mut ability_implementations_for_specialization = specialization_must_implement_constraints + .iter() + .filter(|mia| mia.ability == ability) + .collect::>(); + ability_implementations_for_specialization.dedup(); + + debug_assert!(ability_implementations_for_specialization.len() == 1, "Multiple variables bound to an ability - this is ambiguous and should have been caught in canonicalization"); + + ability_implementations_for_specialization + .pop() + .unwrap() + .typ +} diff --git a/compiler/solve/src/lib.rs b/compiler/solve/src/lib.rs index 06f9e2fd5c..d0e42eaf42 100644 --- a/compiler/solve/src/lib.rs +++ b/compiler/solve/src/lib.rs @@ -2,6 +2,6 @@ // See github.com/rtfeldman/roc/issues/800 for discussion of the large_enum_variant check. #![allow(clippy::large_enum_variant)] -mod ability; +pub mod ability; pub mod module; pub mod solve; diff --git a/compiler/solve/src/solve.rs b/compiler/solve/src/solve.rs index a466c99f8c..61814b7282 100644 --- a/compiler/solve/src/solve.rs +++ b/compiler/solve/src/solve.rs @@ -1,4 +1,4 @@ -use crate::ability::{AbilityImplError, DeferredMustImplementAbility}; +use crate::ability::{type_implementing_member, AbilityImplError, DeferredMustImplementAbility}; use bumpalo::Bump; use roc_can::abilities::{AbilitiesStore, MemberSpecialization}; use roc_can::constraint::Constraint::{self, *}; @@ -1346,16 +1346,8 @@ fn check_ability_specialization( // First, figure out and register for what type does this symbol specialize // the ability member. - let mut ability_implementations_for_specialization = must_implement_ability - .iter() - .filter(|mia| mia.ability == root_data.parent_ability) - .collect::>(); - ability_implementations_for_specialization.dedup(); - - debug_assert!(ability_implementations_for_specialization.len() == 1, "Multiple variables bound to an ability - this is ambiguous and should have been caught in canonicalization"); - - // This is a valid specialization! Record it. - let specialization_type = ability_implementations_for_specialization[0].typ; + let specialization_type = + type_implementing_member(&must_implement_ability, root_data.parent_ability); let specialization = MemberSpecialization { symbol, region: symbol_loc_var.region, diff --git a/compiler/test_mono/generated/specialize_ability_call.txt b/compiler/test_mono/generated/specialize_ability_call.txt new file mode 100644 index 0000000000..9e3179be06 --- /dev/null +++ b/compiler/test_mono/generated/specialize_ability_call.txt @@ -0,0 +1,7 @@ +procedure Test.5 (Test.8): + ret Test.8; + +procedure Test.0 (): + let Test.10 : U64 = 1234i64; + let Test.9 : U64 = CallByName Test.5 Test.10; + ret Test.9; diff --git a/compiler/test_mono/src/tests.rs b/compiler/test_mono/src/tests.rs index 74e17f9a15..bba051e6b6 100644 --- a/compiler/test_mono/src/tests.rs +++ b/compiler/test_mono/src/tests.rs @@ -1294,6 +1294,25 @@ fn issue_2811() { ) } +#[mono_test] +fn specialize_ability_call() { + indoc!( + r#" + app "test" provides [ main ] to "./platform" + + Hash has + hash : a -> U64 | a has Hash + + Id := U64 + + hash : Id -> U64 + hash = \$Id n -> n + + main = hash ($Id 1234) + "# + ) +} + // #[ignore] // #[mono_test] // fn static_str_closure() { diff --git a/compiler/unify/src/unify.rs b/compiler/unify/src/unify.rs index a586cdaed7..35685d8a4c 100644 --- a/compiler/unify/src/unify.rs +++ b/compiler/unify/src/unify.rs @@ -140,6 +140,18 @@ pub enum Unified { BadType(Pool, roc_types::types::Problem), } +impl Unified { + pub fn expect_success(self, err_msg: &'static str) -> (Pool, Vec) { + match self { + Unified::Success { + vars, + must_implement_ability, + } => (vars, must_implement_ability), + _ => panic!("{}", err_msg), + } + } +} + /// Specifies that `type` must implement the ability `ability`. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct MustImplementAbility {